rxnn 0.1.14__tar.gz → 0.1.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {rxnn-0.1.14 → rxnn-0.1.15}/PKG-INFO +1 -1
  2. {rxnn-0.1.14 → rxnn-0.1.15}/pyproject.toml +1 -1
  3. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/moe.py +33 -23
  4. {rxnn-0.1.14 → rxnn-0.1.15}/LICENSE +0 -0
  5. {rxnn-0.1.14 → rxnn-0.1.15}/README.md +0 -0
  6. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/__init__.py +0 -0
  7. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/experimental/__init__.py +0 -0
  8. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/experimental/attention.py +0 -0
  9. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/memory/__init__.py +0 -0
  10. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/memory/norm.py +0 -0
  11. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/memory/stm.py +0 -0
  12. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/rxt/__init__.py +0 -0
  13. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/rxt/models.py +0 -0
  14. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/__init__.py +0 -0
  15. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/base.py +0 -0
  16. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/bml.py +0 -0
  17. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/callbacks.py +0 -0
  18. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/dataset.py +0 -0
  19. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/scheduler.py +0 -0
  20. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/training/tokenizer.py +0 -0
  21. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/__init__.py +0 -0
  22. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/attention.py +0 -0
  23. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/ff.py +0 -0
  24. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/layers.py +0 -0
  25. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/mask.py +0 -0
  26. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/models.py +0 -0
  27. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/positional.py +0 -0
  28. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/sampler.py +0 -0
  29. {rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/utils.py +0 -0
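
Of the 29 files, only src/rxnn/transformers/moe.py carries code changes (+33 -23); the PKG-INFO and pyproject.toml hunks below are the version bump, and the remaining 26 files are unchanged. A runnable sketch of the new selected-expert computation follows the moe.py hunk.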
{rxnn-0.1.14 → rxnn-0.1.15}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.14
+Version: 0.1.15
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.1.14 → rxnn-0.1.15}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.1.14"
+version = "0.1.15"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.1.14 → rxnn-0.1.15}/src/rxnn/transformers/moe.py
@@ -106,34 +106,44 @@ class MoeFeedForward(nn.Module):
         x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
 
         # Get routing weights and indices
-        weights, indices = self.router(x) # [batch*seq_len, top_k], [batch*seq_len, top_k]
+        weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
 
         # Flatten indices and weights
-        batch_size = x.size(0)
-        top_k = indices.size(1)
-        indices = indices.view(-1) # [batch*seq_len * top_k]
-        weights = weights.view(-1, 1) # [batch*seq_len * top_k, 1]
-
-        # Select only the relevant experts for each token
-        selected_w1 = self.w1[indices] # [batch*seq_len * top_k, embed_dim, hidden_dim]
-        selected_b1 = self.b1[indices] # [batch*seq_len * top_k, hidden_dim]
-        selected_w2 = self.w2[indices] # [batch*seq_len * top_k, hidden_dim, embed_dim]
-        selected_b2 = self.b2[indices] # [batch*seq_len * top_k, embed_dim]
-
-        # Reshape x for batched computation
-        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [batch*seq_len * top_k, embed_dim]
-
-        # Compute only the selected experts
-        h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
+        batch_size = x.shape[0]
+        top_k = indices.shape[1]
+        indices_flat = indices.view(-1) # [B*T * top_k]
+
+        # Compute contributions for selected experts without materializing large tensors
+        # First Layer:
+        # Compute all expert contributions first (but this may still be memory-heavy)
+        # Alternative: Compute contributions for selected experts directly
+        # ... (see detailed steps below)
+
+        # Alternative approach using gather and batched operations
+        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
+
+        # Compute first layer contributions using gather
+        # indices_flat has shape [B*T*top_k]
+        # selected_w1 is self.w1[indices_flat], but we compute the product inline
+        h = torch.einsum(
+            'be, eih -> bh',
+            x_expanded,
+            self.w1[indices_flat]
+        ) + self.b1[indices_flat]
         h = self._activate(h)
         h = self.dropout(h)
 
-        out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2
-
-        # Reshape back and apply weights
-        out = out.view(batch_size, top_k, -1) # [batch*seq_len, top_k, embed_dim]
-        weights = weights.view(batch_size, top_k, 1) # [batch*seq_len, top_k, 1]
-        out = (out * weights).sum(dim=1) # Weighted sum over top_k experts
+        # Second layer:
+        out = torch.einsum(
+            'bh, eho -> beo',
+            h,
+            self.w2[indices_flat]
+        ).squeeze(-1) + self.b2[indices_flat]
+
+        # Reshape and apply weights
+        out = out.view(batch_size, top_k, -1)
+        weights = weights.view(batch_size, top_k, 1)
+        out = (out * weights).sum(dim=1)
 
         return out.view(*orig_shape)
 
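The moe.py change drops the pre-gathered selected_w1/selected_b1/selected_w2/selected_b2 tensors and indexes the expert parameters inline inside the einsum calls, per the new comments about avoiding materializing large intermediates. Below is a minimal, self-contained sketch of that routing path, not the package's code: the sizes, the relu stand-in for self._activate and dropout, and the softmax-plus-topk stand-in for self.router are all assumptions for illustration. Since w1[indices_flat] gathers one [embed_dim, hidden_dim] matrix per token slot, the sketch labels the gathered operands 'beh' and 'bhe' in the einsums.

import torch
import torch.nn.functional as F

# Toy sizes; illustrative assumptions, not the package's defaults.
num_experts, embed_dim, hidden_dim, top_k = 8, 16, 32, 2
tokens = 4 * 5  # batch * seq_len

# Per-expert parameters, shaped like MoeFeedForward's w1/b1/w2/b2.
w1 = torch.randn(num_experts, embed_dim, hidden_dim)
b1 = torch.randn(num_experts, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, embed_dim)
b2 = torch.randn(num_experts, embed_dim)

x = torch.randn(tokens, embed_dim)  # [B*T, embed_dim]

# Stand-in for self.router(x): softmax gate over a random projection, then top-k.
gate = torch.randn(embed_dim, num_experts)
weights, indices = torch.topk(F.softmax(x @ gate, dim=-1), top_k, dim=-1)

indices_flat = indices.view(-1)                                      # [B*T * top_k]
x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, embed_dim)  # [B*T*top_k, D]

# First projection: gather each token slot's expert and contract over embed_dim.
h = torch.einsum('be,beh->bh', x_expanded, w1[indices_flat]) + b1[indices_flat]
h = torch.relu(h)  # stand-in for self._activate + dropout

# Second projection back to embed_dim, again per selected expert.
out = torch.einsum('bh,bhe->be', h, w2[indices_flat]) + b2[indices_flat]

# Weighted sum of the top_k expert outputs per token.
out = (out.view(tokens, top_k, embed_dim) * weights.unsqueeze(-1)).sum(dim=1)

# Slow reference loop over the same routing decisions, to sanity-check the
# vectorized path.
ref = torch.zeros_like(out)
for t in range(tokens):
    for k in range(top_k):
        e = indices[t, k]
        ref[t] += weights[t, k] * (torch.relu(x[t] @ w1[e] + b1[e]) @ w2[e] + b2[e])
print(out.shape, torch.allclose(out, ref, atol=1e-4))  # torch.Size([20, 16]) True

The shared b subscript makes each einsum a batched matrix-vector product, so every token slot is multiplied only by the one expert it routed to, and the router weights then mix the top_k results; the loop at the end is just a reference implementation of the same routing decisions.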