rxnn-0.1.31.tar.gz → rxnn-0.1.32.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.1.31 → rxnn-0.1.32}/PKG-INFO +1 -1
- {rxnn-0.1.31 → rxnn-0.1.32}/pyproject.toml +1 -1
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/attention.py +23 -32
- {rxnn-0.1.31 → rxnn-0.1.32}/LICENSE +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/README.md +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/base.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/utils.py +0 -0
{rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/attention.py

@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):

        # Key/Value MoE routing
        B, S, D = key.shape
-        key_flat = key.reshape(
+        key_flat = key.reshape(-1, D)
        weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
        weights = weights.view(B, S, self.num_groups, 1)
        indices = indices.view(B, S, self.num_groups)
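The router call above is expected to return per-token expert weights and indices, each of shape (B*S, num_groups), as the inline comment notes. For orientation, the sketch below is a hypothetical top-k softmax router with that interface; it is a stand-in for illustration only, not the package's actual router module.

```python
import torch
import torch.nn as nn

class TopKRouter(nn.Module):
    """Hypothetical stand-in: score every expert per token, keep the top-k."""

    def __init__(self, dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        logits = self.gate(x)                                  # [B*S, num_experts]
        weights, indices = torch.topk(logits, self.top_k, dim=-1)
        weights = torch.softmax(weights, dim=-1)               # renormalize over the kept experts
        return weights, indices                                # each [B*S, top_k]

router = TopKRouter(dim=64, num_experts=4, top_k=2)
key_flat = torch.randn(16, 64)                                 # e.g. B*S = 16
weights, indices = router(key_flat)
print(weights.shape, indices.shape)  # torch.Size([16, 2]) torch.Size([16, 2])
```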
@@ -331,35 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
        # Compute all experts' projections
        # Shape: (B*S, num_experts, head_dim)
        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(
+        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)

-
-
-
-
-        print('k_all', k_all.size())
-        print('v_all', v_all.size())
-
-
-        # Gather top-k experts and weights
-        # Expand indices to [B, S, num_groups, head_dim]
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)
+        if self.use_bias:
+            k_all += self.bk
+            v_all += self.bv

-
-
+        # Get results for all heads
+        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]

-        #
-
-
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]

-
-
+        # Weighted
+        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
+        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]

        # Reshape to GQA format
-        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
-        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]

        if not self.use_flash_attention:
            group_heads = self.num_heads // self.num_groups
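Taken together, the rewritten block projects every token through all key/value experts in one einsum, adds the optional per-expert bias, gathers only the experts the router selected for each token, applies the routing weights, and reshapes the result into the grouped-query layout. The following is a minimal, self-contained sketch of that pattern with made-up sizes and random stand-ins for self.wk, self.bk, and the router output; it mirrors the shapes annotated in the diff rather than the package's real modules.

```python
import torch

# Illustrative sizes only; the real module derives these from its config.
B, S, D = 2, 8, 64               # batch, sequence length, model dim
num_experts, num_groups = 4, 2   # total K/V experts, routed groups kept per token
head_dim = D // 8

# Stand-ins for self.wk / self.bk.
wk = torch.randn(num_experts, D, head_dim)
bk = torch.zeros(num_experts, head_dim)

key = torch.randn(B, S, D)
key_flat = key.reshape(-1, D)                       # [B*S, D]

# Project every token through every expert in a single einsum.
k_all = torch.einsum('bd,edh->beh', key_flat, wk)   # [B*S, num_experts, head_dim]
k_all = k_all + bk                                  # broadcast per-expert bias over tokens
k_all = k_all.view(B, S, num_experts, head_dim)     # [B, S, num_experts, head_dim]

# Pretend router output: top-k expert ids and normalized weights per token.
indices = torch.randint(0, num_experts, (B, S, num_groups))            # [B, S, num_groups]
weights = torch.softmax(torch.randn(B, S, num_groups), dim=-1).unsqueeze(-1)

# Keep only the routed experts and apply their weights.
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)          # [B, S, num_groups, head_dim]
selected_k = torch.gather(k_all, 2, expanded)                          # [B, S, num_groups, head_dim]
weighted_k = selected_k * weights                                      # [B, S, num_groups, head_dim]

# Reshape to the grouped-query attention layout.
k = weighted_k.view(B, S, num_groups, -1).permute(0, 2, 1, 3)          # [B, num_groups, S, head_dim]
print(k.shape)  # torch.Size([2, 2, 8, 8])
```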
@@ -370,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)

-        print('q', q.size())
-        print('k', k.size())
-        print('v', v.size())
-
        return q, k, v


@@ -458,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):

    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
        B, T, D = query.shape
-        query_flat = query.reshape(
+        query_flat = query.reshape(-1, D)
        weights_q, indices_q = self.query_router(query_flat)
        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
        indices_q = indices_q.view(B, T, self.num_query_groups)

        # Compute all query experts
        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all += self.bq
+
        q_all = q_all.view(B, T, self.num_query_experts, -1)

        # Gather top-k experts
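The query path in DeepMoeAttentionVectorized relies on the same trick as the key/value path: torch.einsum('bd,edh->beh', ...) projects every flattened token through all expert matrices at once instead of looping over experts. A quick sanity check of that equivalence, with arbitrary sizes and random tensors for illustration only:

```python
import torch

torch.manual_seed(0)
num_experts, D, head_dim = 3, 16, 4
x = torch.randn(10, D)                        # flattened tokens, [B*T, D]
w = torch.randn(num_experts, D, head_dim)     # stand-in for self.wq / self.wk

# One einsum over all experts vs. an explicit per-expert matmul loop.
vectorized = torch.einsum('bd,edh->beh', x, w)                       # [B*T, E, head_dim]
looped = torch.stack([x @ w[e] for e in range(num_experts)], dim=1)  # [B*T, E, head_dim]

assert torch.allclose(vectorized, looped, atol=1e-6)
```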
@@ -472,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]

        # Weighted sum
-        q =
-        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B,
+        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)

-
# Others

class FlexAttention(MultiHeadAttention):