rxnn 0.1.13__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {rxnn-0.1.13 → rxnn-0.1.14}/PKG-INFO +1 -1
  2. {rxnn-0.1.13 → rxnn-0.1.14}/pyproject.toml +1 -1
  3. rxnn-0.1.14/src/rxnn/experimental/attention.py +380 -0
  4. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/moe.py +46 -12
  5. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/sampler.py +82 -22
  6. rxnn-0.1.13/src/rxnn/experimental/attention.py +0 -133
  7. {rxnn-0.1.13 → rxnn-0.1.14}/LICENSE +0 -0
  8. {rxnn-0.1.13 → rxnn-0.1.14}/README.md +0 -0
  9. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/memory/norm.py +0 -0
  13. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/memory/stm.py +0 -0
  14. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/rxt/__init__.py +0 -0
  15. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/rxt/models.py +0 -0
  16. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/__init__.py +0 -0
  17. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/base.py +0 -0
  18. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/bml.py +0 -0
  19. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/callbacks.py +0 -0
  20. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/dataset.py +0 -0
  21. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/attention.py +0 -0
  25. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/ff.py +0 -0
  26. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/layers.py +0 -0
  27. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/mask.py +0 -0
  28. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/models.py +0 -0
  29. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/positional.py +0 -0
  30. {rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/utils.py +0 -0
{rxnn-0.1.13 → rxnn-0.1.14}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.13
+ Version: 0.1.14
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
{rxnn-0.1.13 → rxnn-0.1.14}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
  
  [tool.poetry]
  name = "rxnn"
- version = "0.1.13"
+ version = "0.1.14"
  description = "RxNN: Reactive Neural Networks Platform"
  
  license = "Apache-2.0"
rxnn-0.1.14/src/rxnn/experimental/attention.py (new file)
@@ -0,0 +1,380 @@
+ import torch
+ from torch import nn
+ from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
+ from ..transformers.positional import RotaryPositionalEmbedding
+ from ..transformers.moe import MoeRouter
+
+ # Created by Reactive AI
+
+ class GroupedMoeAttention(GroupedQueryAttention):
+     """
+     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+     Instead of mapping keys/values to static head groups, it dynamically selects expert head groups. It has the same
+     total number of key/value heads as query heads, but uses only a selected group for the attention calculation.
+     - with num_groups set to 1, it becomes MoE MultiQueryAttention
+
+     Compared to traditional GQA/MQA, it should provide better quality, because much less information is lost with
+     this approach - the full number of key/value heads is trained, while only a group is used per token.
+
+     In terms of efficiency, it should be close to GQA/MQA linear performance, with only a small MoE routing overhead.
+
+     Optionally, it can use even more expert heads than attention heads - for example:
+     - 512 dims divided into 16 heads of 32 dims, using 4 head groups, may use e.g. 24 total expert heads - still only
+       4 are used for the attention calculation, while 16 are used to split the dimensions (in that case it will have 16 query heads)
+     """
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_experts = num_experts if num_experts is not None else num_heads
+         super(GroupedMoeAttention, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             *args,
+             **kwargs,
+         )
+
+     def _init_kv(self, embed_dim: int):
+         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+         hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+         self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self._init_experts()
+
+     def _init_experts(self):
+         torch.nn.init.xavier_uniform_(self.wk)
+         torch.nn.init.xavier_uniform_(self.wv)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bk)
+             torch.nn.init.zeros_(self.bv)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+         # head_dim = d // self.num_heads
+         # group_heads = self.num_heads // self.num_groups
+         #
+         # # Process Query as in GQA
+         # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
+         #
+         # # Process Key and Value with MoE routing
+         # key_flat = key.view(-1, d)
+         # weights, indices = self.router(key_flat)
+         # weights = weights.view(b, key.size(1), self.num_groups, 1)
+         # indices = indices.view(b, key.size(1), self.num_groups)
+         #
+         # # Compute all experts' K and V projections
+         # # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
+         # k_all = torch.einsum(
+         #     'be, ehd -> bedh',
+         #     key_flat,
+         #     self.wk.view(self.num_experts, d, -1)
+         # ).view(b, key.size(1), self.num_experts, -1)
+         #
+         # v_all = torch.einsum(
+         #     'be, ehd -> bedh',
+         #     value.view(-1, d),
+         #     self.wv.view(self.num_experts, d, -1)
+         # ).view(b, value.size(1), self.num_experts, -1)
+         #
+         # # Select top_k experts and compute weighted sum
+         # selected_k = torch.gather(
+         #     k_all,
+         #     2,
+         #     indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+         # )
+         # selected_v = torch.gather(
+         #     v_all,
+         #     2,
+         #     indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
+         # )
+         #
+         # selected_k = (selected_k * weights).sum(dim=2)
+         # selected_v = (selected_v * weights).sum(dim=2)
+         # # Reshape to GQA format: (B, G, S, head_dim)
+         # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+         # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+         #
+         # if not self.use_flash_attention:
+         #     group_heads = self.num_heads // self.num_groups
+         #
+         #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+         #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+         #
+         #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+         #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+         #
+         # return q, k, v
+         head_dim = d // self.num_heads
+
+         # Process Query as in GQA
+         q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2) if not skip_query_processing else query
+
+         # Process Key and Value with MoE routing
+         key_flat = key.view(-1, d)  # (B*S, d)
+         value_flat = value.view(-1, d)  # (B*S, d)
+
+         # Get routing indices and weights for K
+         weights_k, indices_k = self.router(key_flat)
+         indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
+         weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
+
+         # Select and compute K projections for only the top_k experts
+         selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
+         k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
+         selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
+         selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
+
+         # Compute V using the same indices as K (since they share the same router)
+         selected_v_weights = self.v_experts[indices_k]
+         v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
+         selected_v = (v_proj * weights_k).sum(dim=1)
+         selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
+
+         # Reshape to GQA format: (B, G, S, head_dim)
+         k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+         v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+
+         if not self.use_flash_attention:
+             group_heads = self.num_heads // self.num_groups
+
+             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+             v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+         return q, k, v
+
+ class SparseMoeAttention(GroupedMoeAttention):
+     """
+     Sparse MoE Attention (SMA) - Grouped MoE Attention extended even further for sublinear computational efficiency.
+     In addition to using Mixture-of-Experts (MoE) routing for key/value head groups, SMA also uses dynamically selected
+     query heads - with that approach, each token can attend to every other token, but only partially - only part of
+     the information from each token is used to identify related information in other tokens.
+
+     This solution could reduce the computational complexity of the attention operation to a sublinear level (<O(N)).
+     """
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             num_query_experts: int = None,
+             num_active_query_heads: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+         self.num_active_query_heads = num_active_query_heads if num_active_query_heads is not None else num_groups
+         super(SparseMoeAttention, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_experts=num_experts,
+             *args,
+             **kwargs,
+         )
+
+     def _init_q(self, embed_dim: int):
+         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_active_query_heads)
+         hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+         self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+         self._init_query_experts()
+
+     def _init_query_experts(self):
+         torch.nn.init.xavier_uniform_(self.wq)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bq)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+         head_dim = d // self.num_heads
+
+         # Process Query with MoE routing
+         query_flat = query.view(-1, d)  # (B*T, d)
+         weights_q, indices_q = self.router_q(query_flat)
+         indices_q = indices_q.view(-1, self.top_k_q)  # (B*T, top_k_q)
+         weights_q = weights_q.view(-1, self.top_k_q, 1)  # (B*T, top_k_q, 1)
+
+         # Select and compute Q projections for top_k experts
+         selected_q_weights = self.q_experts[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
+         q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
+         selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
+         selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
+
+         q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
+
+         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
+
+ # Others
+
+ class FlexAttention(MultiHeadAttention):
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_global_tokens: int = 16,
+             window_size: int = 128,
+             **kwargs
+     ):
+         super().__init__(embed_dim, num_heads, **kwargs)
+         self.num_global_tokens = num_global_tokens
+         self.window_size = window_size
+         self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))
+
+     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None):
+         b, t, d = query.size()
+         head_dim = d // self.num_heads
+
+         # Split into global and local
+         x = torch.cat([self.global_tokens.expand(b, -1, -1), query], dim=1)
+         seq_len = x.size(1)
+         num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+
+         # Project Q, K, V
+         q, k, v = self._forward_qkv(x, key, value, b, seq_len, d)
+
+         # Process Global-to-Global Attention
+         global_q = q[:, :, :self.num_global_tokens]  # [B, H, G, head_dim]
+         global_k = k[:, :, :self.num_global_tokens]
+         global_v = v[:, :, :self.num_global_tokens]
+         global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
+
+         # Process Global-to-Local Attention
+         local_k = k[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
+         local_v = v[:, :, self.num_global_tokens:]
+         # Apply RoPE to local_k if needed
+         if self.rope:
+             # Compute frequencies for entire local sequence
+             local_k = self.rope.forward_one(local_k)
+
+         global_local_attn = self._calculate_attn_weights(global_q, local_k, d) @ local_v
+
+         # Process Local-to-Local Attention (per window)
+         local_q = q[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
+         local_q = local_q.view(b, self.num_heads, num_windows, self.window_size, head_dim)
+         local_k = local_k.view(b, self.num_heads, num_windows, self.window_size, head_dim)
+         local_v = local_v.view(b, self.num_heads, num_windows, self.window_size, head_dim)
+
+         local_attn = []
+         for i in range(num_windows):
+             window_q = local_q[:, :, i]  # [B, H, window_size, head_dim]
+             window_k = local_k[:, :, i]
+             window_v = local_v[:, :, i]
+
+             # Apply RoPE to window_q and window_k
+             if self.rope:
+                 # Compute frequencies for this window
+                 window_q, window_k = self.rope(window_q, window_k)
+
+             # Calculate attention for this window
+             attn = self._calculate_attn_weights(window_q, window_k, d)
+             attn_i = torch.einsum('bhij, bhjd -> bhid', attn, window_v)
+             local_attn.append(attn_i)
+         local_attn = torch.cat(local_attn, dim=2).view(b, self.num_heads, -1, head_dim)
+
+         # Combine all attention outputs
+         combined_attn = torch.cat([global_attn, global_local_attn, local_attn], dim=2)
+         output = self._calculate_output(combined_attn, v, b, t, d)
+         return self.out_proj(output)
+
+ class InfiniteAttention(MultiHeadAttention):
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             kernel_size: int = 128,
+             use_rotary: bool = True,
+             **kwargs
+     ):
+         super().__init__(embed_dim, num_heads, **kwargs)
+         self.kernel_size = kernel_size
+         self.use_rotary = use_rotary
+         self.register_buffer("fourier_basis", self._init_fourier_basis(embed_dim))
+
+     def _init_fourier_basis(self, embed_dim):
+         # Initialize Fourier features for positional encoding
+         freqs = torch.randn(embed_dim // 2)
+         return freqs
+
+     def _positional_encodings(self, x: torch.Tensor, device: torch.device):
+         """Generate positional encodings for arbitrary sequence length."""
+         seq_len = x.size(1)
+         pos = torch.arange(seq_len, device=device).float()
+         fourier_features = torch.einsum("d, s -> sd", self.fourier_basis, pos)
+         pe = torch.cat([torch.sin(fourier_features), torch.cos(fourier_features)], dim=1)
+         return pe.unsqueeze(0).expand(x.size(0), -1, -1)
+
+     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None):
+         b, t, d = query.size()
+         # Add positional encodings
+         pe = self._positional_encodings(query, query.device)
+         query = query + pe
+         key = key + pe
+
+         # Split into chunks for kernel-based attention
+         chunks = []
+         for i in range(0, t, self.kernel_size):
+             chunk = query[:, i:i + self.kernel_size]
+             chunks.append(chunk)
+
+         # Compute attention for each chunk
+         attn_output = []
+         for chunk in chunks:
+             q, k, v = self._forward_qkv(chunk, key, value, b, chunk.size(1), d)
+             # Use kernel approximation (e.g., Performer)
+             attn = self._performer_attention(q, k, v)
+             attn_output.append(attn)
+
+         # Concatenate and apply output projection
+         output = torch.cat(attn_output, dim=1)
+         return self.out_proj(output)
+
+     def _performer_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+         # Performer kernel approximation (simplified)
+         # TODO: Replace with preferred kernel method
+         q = q / (q.shape[-1] ** 0.5)
+         attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+         attn = torch.softmax(attn, dim=-1)
+         return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
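The key/value routing scheme described in the GroupedMoeAttention docstring above can be illustrated with a small, self-contained sketch. Everything below is an illustrative assumption rather than the package API: a plain Linear layer stands in for MoeRouter, wk stands in for the per-expert key projections, and the dimensions are example values. Each token picks top_k of num_experts projection heads, the weighted result has hidden_dim = num_groups * head_dim, and it is reshaped into the (B, G, S, head_dim) layout that GQA-style attention expects.

import torch

B, S, D = 2, 10, 512                         # batch, sequence, embedding dim (assumed)
num_heads, num_groups = 16, 4
num_experts, top_k = 24, num_groups          # more expert heads than attention heads is allowed
head_dim = D // num_heads                    # 32
hidden_dim = D // (num_heads // num_groups)  # num_groups * head_dim = 128

tokens = torch.randn(B, S, D)
router = torch.nn.Linear(D, num_experts)      # stand-in for MoeRouter
wk = torch.randn(num_experts, D, hidden_dim)  # per-expert K projection, as in _init_kv

flat = tokens.view(-1, D)                                  # (B*S, D)
scores = torch.softmax(router(flat), dim=-1)               # (B*S, num_experts)
weights, indices = scores.topk(top_k, dim=-1)              # (B*S, top_k)
selected_wk = wk[indices]                                  # gather only selected experts: (B*S, top_k, D, hidden_dim)
k_proj = torch.einsum('nd,nkdh->nkh', flat, selected_wk)   # (B*S, top_k, hidden_dim)
k = (k_proj * weights.unsqueeze(-1)).sum(dim=1)            # weighted mix: (B*S, hidden_dim)
k = k.view(B, S, num_groups, head_dim).transpose(1, 2)     # (B, G, S, head_dim), GQA layout
print(k.shape)  # torch.Size([2, 4, 10, 32])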
{rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/moe.py
@@ -77,29 +77,63 @@ class MoeFeedForward(nn.Module):
          return self.router.aux_loss
  
      def forward(self, x: torch.Tensor):
+         # orig_shape = x.shape
+         # x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+         #
+         # # Get routing weights and indices
+         # weights, indices = self.router(x)  # [batch*seq_len, top_k]
+         #
+         # # Create expert masks and combine it with masks
+         # mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+         # weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+         #
+         # # Expert computation
+         # x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+         #
+         # # First linear layer
+         # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+         # h = self._activate(h)
+         # h = self.dropout(h)
+         #
+         # # Second linear layer (projection back to embed_dim)
+         # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+         #
+         # # Weighted sum of expert outputs
+         # out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+         #
+         # return out.view(*orig_shape)
          orig_shape = x.shape
          x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
  
          # Get routing weights and indices
-         weights, indices = self.router(x)  # [batch*seq_len, top_k]
+         weights, indices = self.router(x)  # [batch*seq_len, top_k], [batch*seq_len, top_k]
  
-         # Create expert masks and combine it with masks
-         mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
-         weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+         # Flatten indices and weights
+         batch_size = x.size(0)
+         top_k = indices.size(1)
+         indices = indices.view(-1)  # [batch*seq_len * top_k]
+         weights = weights.view(-1, 1)  # [batch*seq_len * top_k, 1]
  
-         # Expert computation
-         x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+         # Select only the relevant experts for each token
+         selected_w1 = self.w1[indices]  # [batch*seq_len * top_k, embed_dim, hidden_dim]
+         selected_b1 = self.b1[indices]  # [batch*seq_len * top_k, hidden_dim]
+         selected_w2 = self.w2[indices]  # [batch*seq_len * top_k, hidden_dim, embed_dim]
+         selected_b2 = self.b2[indices]  # [batch*seq_len * top_k, embed_dim]
  
-         # First linear layer
-         h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+         # Reshape x for batched computation
+         x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim)  # [batch*seq_len * top_k, embed_dim]
+
+         # Compute only the selected experts
+         h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
          h = self._activate(h)
          h = self.dropout(h)
  
-         # Second linear layer (projection back to embed_dim)
-         out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+         out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2
  
-         # Weighted sum of expert outputs
-         out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+         # Reshape back and apply weights
+         out = out.view(batch_size, top_k, -1)  # [batch*seq_len, top_k, embed_dim]
+         weights = weights.view(batch_size, top_k, 1)  # [batch*seq_len, top_k, 1]
+         out = (out * weights).sum(dim=1)  # Weighted sum over top_k experts
  
          return out.view(*orig_shape)
  
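For reference, here is a minimal, self-contained sketch of the sparse expert-selection pattern the new MoeFeedForward.forward uses above: gather only the top_k selected experts' weight matrices per token instead of evaluating every expert. The names, activation, and sizes below are illustrative assumptions, not rxnn's API.

import torch

N, D, H = 6, 32, 64           # tokens (batch*seq), embed dim, hidden dim (assumed)
num_experts, top_k = 8, 2

x = torch.randn(N, D)
w1 = torch.randn(num_experts, D, H)
w2 = torch.randn(num_experts, H, D)
router_logits = torch.randn(N, num_experts)                           # stand-in for MoeRouter scores
weights, indices = torch.softmax(router_logits, dim=-1).topk(top_k, dim=-1)

flat_idx = indices.view(-1)                                   # (N*top_k,)
x_rep = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, D)   # repeat each token per selected expert
h = torch.einsum('be,beh->bh', x_rep, w1[flat_idx]).relu()    # only selected experts' first layer
out = torch.einsum('bh,bhe->be', h, w2[flat_idx])             # only selected experts' second layer
out = (out.view(N, top_k, D) * weights.unsqueeze(-1)).sum(dim=1)  # weighted sum over top_k experts
print(out.shape)  # torch.Size([6, 32])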
{rxnn-0.1.13 → rxnn-0.1.14}/src/rxnn/transformers/sampler.py
@@ -1,13 +1,15 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from typing import Iterator
+ from typing import Iterator, Union
+ from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
+
  
  def sample(
-         logits: torch.Tensor,
-         temperature: float = 1.0,
-         top_k: int = None,
-         top_p: float = None,
+         logits: torch.Tensor,
+         temperature: float = 1.0,
+         top_k: int = None,
+         top_p: float = None,
  ) -> torch.Tensor:
      if temperature <= 0:
          raise ValueError("Temperature must be > 0")
@@ -45,6 +47,7 @@ def sample(
      # Sample from distribution
      return torch.multinomial(probs, num_samples=1)
  
+
  class Sampler:
      def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
          self.model = model.to(device)
@@ -52,12 +55,12 @@ class Sampler:
          self.end_token_id = end_token_id
  
      def _generate_token(
-             self,
-             input_ids: torch.Tensor,
-             temperature: float,
-             top_k: int,
-             top_p: float ,
-             attention_mask: torch.Tensor,
+             self,
+             input_ids: torch.Tensor,
+             temperature: float,
+             top_k: int,
+             top_p: float,
+             attention_mask: torch.Tensor,
      ) -> tuple[int, torch.Tensor, torch.Tensor]:
          # Forward pass to get next token logits
          outputs = self.model(input_ids, attention_mask=attention_mask)
@@ -82,14 +85,14 @@ class Sampler:
          )
  
      def __call__(
-             self,
-             initial_tokens: torch.Tensor,
-             temperature: float = 1.0,
-             top_k: int = None,
-             top_p: float = None,
-             max_seq_len: int = 50,
-             attention_mask: torch.Tensor = None,
-             no_grad: bool = True,
+             self,
+             initial_tokens: torch.Tensor,
+             temperature: float = 1.0,
+             top_k: int = None,
+             top_p: float = None,
+             max_seq_len: int = 50,
+             attention_mask: torch.Tensor = None,
+             no_grad: bool = True,
      ) -> Iterator[int]:
          # Convert initial tokens to tensor and move to device
          input_ids = initial_tokens
@@ -97,13 +100,70 @@ class Sampler:
          if no_grad:
              with torch.no_grad():
                  for _ in range(max_seq_len):
-                     next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                     next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                                  attention_mask)
                      yield next_token
                      if next_token == self.end_token_id:
                          break
          else:
              for _ in range(max_seq_len):
-                 next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                 next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                              attention_mask)
                  yield next_token
                  if next_token == self.end_token_id:
-                     break
+                     break
+
+
+ class SampleDecoder:
+     def __init__(self, sampler: Sampler, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
+         self.sampler = sampler
+         self.tokenizer = tokenizer
+         self.device = self.sampler.device
+
+     def tokenize_input(self, text: str):
+         tokenized = self.tokenizer(
+             text,
+             max_length=256,
+             truncation=True,
+             padding=False,
+             return_tensors='pt',
+             return_attention_mask=True
+         )
+         tokenized['input_ids'] = tokenized['input_ids'][:, :-1].to(self.device)
+         tokenized['attention_mask'] = tokenized['attention_mask'][:, :-1].to(self.device)
+         del tokenized['token_type_ids']
+         return tokenized
+
+     def ids_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         tokenized = self.tokenize_input(text)
+         return self.sampler(
+             tokenized['input_ids'],
+             temperature=temperature,
+             top_p=top_p,
+             max_seq_len=max_seq_len,
+             attention_mask=tokenized['attention_mask']
+         )
+
+     def txt_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         return map(
+             lambda x: self.tokenizer.decode([x]).replace('Ċ', '\n').replace('Ġ', ' '),
+             self.ids_iter(text, temperature, top_p, max_seq_len)
+         )
+
+     def txt(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         return text + ''.join(self.txt_iter(text, temperature, top_p, max_seq_len))
+
+     def print_stream(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         print(text, end='')
+         resp = text
+         for txt_token in self.txt_iter(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len):
+             print(txt_token, end='')
+             resp += txt_token
+         return resp
+
+     def __call__(self, text: str, print_stream: bool = False, temperature: float = 0.1, top_p: float = 0.9,
+                  max_seq_len=256):
+         if print_stream:
+             return self.print_stream(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
+         else:
+             return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
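A hedged wiring sketch for the new Sampler/SampleDecoder pair follows. The DummyLM model and the 'bert-base-uncased' tokenizer are placeholders for a trained rxnn model and its tokenizer (they are not shipped with the package), and the sketch assumes the model's forward returns next-token logits of shape (batch, seq_len, vocab_size); it only shows how the pieces plug together.

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from rxnn.transformers.sampler import Sampler, SampleDecoder

class DummyLM(nn.Module):
    # Placeholder model: returns random logits so the generation loop can be exercised end-to-end.
    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, input_ids, attention_mask=None):
        b, t = input_ids.shape
        return torch.randn(b, t, self.vocab_size)

device = torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder tokenizer that returns token_type_ids
model = DummyLM(tokenizer.vocab_size)

sampler = Sampler(model, device, end_token_id=tokenizer.sep_token_id)
decoder = SampleDecoder(sampler, tokenizer)

# Stream generated text to stdout and also collect the full string
completion = decoder('Hello, world', print_stream=True, temperature=0.7, top_p=0.9, max_seq_len=20)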
rxnn-0.1.13/src/rxnn/experimental/attention.py (removed)
@@ -1,133 +0,0 @@
- (The removed module contained the FlexAttention and InfiniteAttention classes plus their imports from rxnn.transformers.attention; both classes are carried over verbatim into rxnn-0.1.14/src/rxnn/experimental/attention.py shown above, which now uses relative imports.)