rxnn 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/rxnn/experimental/attention.py
+++ b/rxnn/experimental/attention.py
@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
 
  # Key/Value MoE routing
  B, S, D = key.shape
- key_flat = key.reshape(B * S, D)
+ key_flat = key.reshape(-1, D)
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
  weights = weights.view(B, S, self.num_groups, 1)
  indices = indices.view(B, S, self.num_groups)
@@ -331,35 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
  # Compute all experts' projections
  # Shape: (B*S, num_experts, head_dim)
  k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
- v_all = torch.einsum('bd,edh->beh', value.view(B*S, D), self.wv)
+ v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
- # Reshape to [B, S, num_experts, head_dim]
- k_all = k_all.view(B, S, self.num_experts, -1)
- v_all = v_all.view(B, S, self.num_experts, -1)
-
- print('k_all', k_all.size())
- print('v_all', v_all.size())
-
-
- # Gather top-k experts and weights
- # Expand indices to [B, S, num_groups, head_dim]
- expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
- selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
- selected_v = torch.gather(v_all, 2, expanded_indices)
+ if self.use_bias:
+ k_all += self.bk
+ v_all += self.bv
 
- print('selected_k', selected_k.size())
- print('selected_v', selected_v.size())
+ # Get results for all heads
+ k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
+ v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
 
- # Weighted sum
- weighted_k = (selected_k * weights).sum(dim=2) # [B, S, head_dim]
- weighted_v = (selected_v * weights).sum(dim=2)
+ # Gather top-k experts using expanded indices
+ expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
+ selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
+ selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
 
- print('weighted_k', weighted_k.size())
- print('weighted_v', weighted_v.size())
+ # Weighted
+ weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
+ weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
 
  # Reshape to GQA format
- k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, G, S, head_dim]
- v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+ k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
+ v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
 
  if not self.use_flash_attention:
  group_heads = self.num_heads // self.num_groups
@@ -370,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
  k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
 
- print('q', q.size())
- print('k', k.size())
- print('v', v.size())
-
  return q, k, v
 
 
@@ -458,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
 
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
  B, T, D = query.shape
- query_flat = query.reshape(B * T, D)
+ query_flat = query.reshape(-1, D)
  weights_q, indices_q = self.query_router(query_flat)
  weights_q = weights_q.view(B, T, self.num_query_groups, 1)
  indices_q = indices_q.view(B, T, self.num_query_groups)
 
  # Compute all query experts
  q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
+ if self.use_bias:
+ q_all += self.bq
+
  q_all = q_all.view(B, T, self.num_query_experts, -1)
 
  # Gather top-k experts
@@ -472,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
  selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
 
  # Weighted sum
- q = (selected_q * weights_q).sum(dim=2) # [B, T, head_dim]
- q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, H_q, T, head_dim]
+ q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
+ q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
 
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-
  # Others
 
  class FlexAttention(MultiHeadAttention):
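
For readers following the hunks above: below is a minimal, self-contained sketch of the router, einsum, gather and weight pattern that GroupedMoeAttentionVectorized applies to keys and values. The stand-in top-k router and the concrete shape values are illustrative assumptions only, not the package's Router implementation.

    import torch

    # Illustrative sizes (assumption): batch, sequence, model dim, experts, active groups, head dim
    B, S, D = 2, 8, 32
    num_experts, num_groups, head_dim = 4, 2, 16

    key = torch.randn(B, S, D)
    wk = torch.randn(num_experts, D, head_dim)  # per-expert projection, like self.wk

    # Stand-in router (assumption): top-k over random scores with softmax weights;
    # the package's Router module is not reproduced here.
    scores = torch.randn(B * S, num_experts)
    topk_scores, indices = scores.topk(num_groups, dim=-1)
    weights = torch.softmax(topk_scores, dim=-1)

    key_flat = key.reshape(-1, D)                            # (B*S, D)
    k_all = torch.einsum('bd,edh->beh', key_flat, wk)        # (B*S, num_experts, head_dim)
    k_all = k_all.view(B, S, num_experts, -1)                # (B, S, num_experts, head_dim)

    weights = weights.view(B, S, num_groups, 1)
    indices = indices.view(B, S, num_groups)

    # Gather each token's selected experts and apply its routing weights
    expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
    selected_k = torch.gather(k_all, 2, expanded_indices)    # (B, S, num_groups, head_dim)
    weighted_k = selected_k * weights                        # (B, S, num_groups, head_dim)

    k = weighted_k.permute(0, 2, 1, 3)                       # (B, num_groups, S, head_dim) for GQA
    print(k.shape)
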
--- a/rxnn/transformers/attention.py
+++ b/rxnn/transformers/attention.py
@@ -1,6 +1,7 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from torch.backends.cuda import sdp_kernel, SDPBackend
  import math
  from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
 
@@ -17,7 +18,7 @@ class MultiHeadAttention(nn.Module):
  rope_only_for_query: bool = False,
  use_relative_embeddings: bool = False,
  max_seq_len: int = 1024,
- use_flash_attention: bool = False,
+ use_flash_attention: bool = True,
  is_causal: bool = False,
  use_bias: bool = False,
  *args,
@@ -101,13 +102,14 @@ class MultiHeadAttention(nn.Module):
 
  def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
  mask: torch.Tensor = None, enable_gqa: bool = False):
- attn_output = F.scaled_dot_product_attention(
- q, k, v,
- attn_mask=mask if not self.is_causal else None,
- dropout_p=self.dropout.p if self.training else 0.0,
- is_causal=self.is_causal,
- enable_gqa=enable_gqa,
- )
+ with sdp_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v,
+ attn_mask=mask if not self.is_causal else None,
+ dropout_p=self.dropout.p if self.training else 0.0,
+ is_causal=self.is_causal,
+ enable_gqa=enable_gqa,
+ )
  return self._transpose_output(attn_output, b, t, d)
 
  def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
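
For reference, a minimal sketch of pinning scaled_dot_product_attention to the flash-attention kernel. It uses torch.nn.attention.sdpa_kernel (available from PyTorch 2.3), which accepts a backend list like the call in the hunk above; note that in the PyTorch releases I am aware of, the older torch.backends.cuda.sdp_kernel context manager imported above takes boolean flags (enable_flash, enable_math, enable_mem_efficient) rather than a backends= argument. Shapes, dtype and the causal flag below are illustrative assumptions.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch 2.3+

    # Illustrative tensors: (batch, heads, seq_len, head_dim) in fp16 on a CUDA device,
    # which is what the flash kernel supports.
    q = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.float16)

    # Allow only the flash-attention backend; SDPA raises a RuntimeError if that
    # backend cannot run (e.g. on CPU or with an explicit attn_mask).
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    print(out.shape)  # torch.Size([2, 8, 128, 64])
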
--- a/rxnn-0.1.31.dist-info/METADATA
+++ b/rxnn-0.1.33.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.31
+ Version: 0.1.33
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
--- a/rxnn-0.1.31.dist-info/RECORD
+++ b/rxnn-0.1.33.dist-info/RECORD
@@ -1,6 +1,6 @@
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/experimental/attention.py,sha256=lqRIR2V7pA5LZQ1x_nhUvQBeJ83DLD6tXdoee-W6ZdU,29833
+ rxnn/experimental/attention.py,sha256=h9pEv_70NKpD5KOQOFP3h-IJKzh7Wbnaxka4Bd3rdd8,29745
  rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
  rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
  rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
  rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
+ rxnn/transformers/attention.py,sha256=Nox986BH9qq4rDYLiYmfj1DeMeULF3akexIl99MPccM,14331
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
  rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
- rxnn-0.1.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.1.31.dist-info/METADATA,sha256=jy1wfyPS-nD34GAMFxcWR_dFn6ZJogJ2-IDxIVmLdmI,16627
- rxnn-0.1.31.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- rxnn-0.1.31.dist-info/RECORD,,
+ rxnn-0.1.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.1.33.dist-info/METADATA,sha256=m3DWDnTu7Lx1kHYPIAQCdKU8t4QZBdqG0QcSIFvB924,16627
+ rxnn-0.1.33.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ rxnn-0.1.33.dist-info/RECORD,,