rxnn 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
@@ -226,6 +226,98 @@ class DeepMoeAttention(GroupedMoeAttention):

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)

+class SparseQueryAttention(MultiHeadAttention):
+    """Sparse Grouped Query attention layer, with RoPE support"""
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            num_query_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            *args,
+            **kwargs,
+    ):
+        self.num_groups = num_groups
+        self.num_query_groups = num_query_groups
+        super(SparseQueryAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+    def _init_kv(self, embed_dim: int):
+        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+    def _init_q(self, embed_dim: int):
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups))
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        """Override query, key, and value projections for GQA case - split data into heads and groups"""
+        head_dim = d // self.num_heads
+        if not self.rel_embed:
+            q = self.q_proj(query).view(b, t, self.num_query_heads, head_dim).transpose(1, 2)
+            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
+            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
+        else:
+            group_heads = self.num_heads // self.num_groups
+            query_heads = self.num_query_heads // self.num_query_groups
+            # Process Q
+            q = self.q_proj(query).view(b, -1, self.num_query_groups, head_dim).transpose(1, 2)  # (B, Q_G, T, head_dim)
+
+            # Process K and V
+            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
+            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
+
+            # Expand and flatten to 4D tensors
+            q = q.unsqueeze(2).expand(-1, -1, query_heads, -1, -1)  # (B, Q_G, query_heads, T, head_dim)
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            q = q.flatten(start_dim=1, end_dim=2)  # (B, Q, T, head_dim)
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+        return q, k, v
+
+    def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        is_gqa = self.num_query_groups != self.num_groups
+        if self.use_flash_attention:
+            # Compute attention with FlashAttention
+            return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+        else:
+            # Compute attention using optimized PyTorch implementation
+            return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+
+
+
+
 class GroupedMoeAttentionSimplified(GroupedQueryAttention):
     """
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
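For orientation, the class added above can be instantiated directly once rxnn 0.1.55 is installed. The following is a minimal sketch (not part of the diff): the hyperparameter values are illustrative rather than package defaults, and it assumes the parent MultiHeadAttention constructor builds its projections through the overridden _init_q/_init_kv/_init_out hooks, as the code above implies.

# Minimal sketch: constructing the new SparseQueryAttention layer.
# Assumes rxnn 0.1.55 (and torch) is installed; values are illustrative only.
from rxnn.experimental.attention import SparseQueryAttention

sqa = SparseQueryAttention(
    embed_dim=512,
    num_heads=16,
    num_groups=4,        # key/value groups, as in GQA
    num_query_groups=8,  # fewer query groups than heads ("sparse query")
)

# With these settings the projections are narrower than in standard MHA:
# k_proj / v_proj: 512 -> 512 // (16 // 4) = 128
# q_proj:          512 -> 512 // (16 // 8) = 256
print(sqa.q_proj, sqa.k_proj, sqa.out_proj)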
@@ -607,8 +699,8 @@ def init_moe_attention(
        num_experts: int = None,
        num_query_experts: int = None,
        num_query_groups: int = None,
-) -> GroupedQueryAttention:
-    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
+) -> MultiHeadAttention:
+    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s', 'sqa'"

    if attention_type == "gma":
        return GroupedMoeAttention(
@@ -642,6 +734,21 @@ def init_moe_attention(
            num_query_experts=num_query_experts,
            num_query_groups=num_query_groups,
        )
+    elif attention_type == 'sqa':
+        return SparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
    elif attention_type == "gma_s":
        return GroupedMoeAttentionSimplified(
            embed_dim,
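For completeness, a hedged sketch of selecting the new type through the factory registered above (again, not part of the diff). The keyword names mirror the 'sqa' branch in this hunk; the values are illustrative, the import path assumes init_moe_attention lives in rxnn.experimental.attention alongside the classes above, and parameters not visible in the diff are assumed to keep their defaults.

# Sketch only: building the layer via the init_moe_attention factory.
from rxnn.experimental.attention import SparseQueryAttention, init_moe_attention

attn = init_moe_attention(
    embed_dim=512,
    num_heads=16,
    attention_type='sqa',
    gqa_groups=4,          # forwarded to SparseQueryAttention as num_groups
    num_query_groups=8,
    dropout=0.0,
    rope=None,
    max_seq_len=1024,
    use_flash_attention=False,
    is_causal=True,
    use_bias=False,
)
assert isinstance(attn, SparseQueryAttention)  # annotated return type is MultiHeadAttention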
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
        assert ff_activation in ['relu', 'gelu',
                                 'swish', 'silu', 'linear',
                                 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_s', 'dma_s'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_s", "dma_s".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

        embedding = nn.Embedding(vocab_size, embed_dim)
        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -1,6 +1,6 @@
Metadata-Version: 2.3
Name: rxnn
-Version: 0.1.54
+Version: 0.1.55
Summary: RxNN: Reactive Neural Networks Platform
License: Apache-2.0
Keywords: deep-learning,ai,machine-learning
@@ -1,7 +1,7 @@
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=ZYdRxz4ik7knk3VS_9Opzy6ZqVF98FIhSNjsmIUhGfk,29532
-rxnn/experimental/models.py,sha256=-BQn7gWlSHLpkAQdthPW5L9ZNzIBqSJS9tkm2N88jgw,4711
+rxnn/experimental/attention.py,sha256=oPknT_PVcNwZvDwpZM7gmP4M_md_FW8oYwJDdQk1avM,34544
+rxnn/experimental/models.py,sha256=iprFSQDPK75zebDJBJ1i-mnNS9jlGf9RAIk-S0E9D-Q,4689
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.54.dist-info/METADATA,sha256=FF9XlvOeROGLpVR5pHuuceoeXTzbMNJhEusmQdfPTD0,16627
-rxnn-0.1.54.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.54.dist-info/RECORD,,
+rxnn-0.1.55.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.55.dist-info/METADATA,sha256=qiMp63aMlBdbvVvSJDL2bfW5XoR0PzNwN6pWdkfCuOM,16627
+rxnn-0.1.55.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.55.dist-info/RECORD,,