rxnn 0.1.56__py3-none-any.whl → 0.1.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -683,7 +683,7 @@ class InfiniteAttention(MultiHeadAttention):
         attn = torch.softmax(attn, dim=-1)
         return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
 
-def init_moe_attention(
+def init_experimental_attention(
         embed_dim: int,
         num_heads: int,
         attention_type: str,
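Only the factory's name changes here; its signature is untouched. A minimal call sketch, reconstructed from the call sites later in this diff (the argument values, and the import path for RotaryPositionalEmbedding, are assumptions, not taken from the package docs):

    from rxnn.transformers.positional import RotaryPositionalEmbedding  # path assumed from RECORD
    from rxnn.experimental.attention import init_experimental_attention

    embed_dim, att_heads, att_groups, seq_len = 512, 16, 4, 1024
    rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)

    # Keyword set mirrors the init_experimental_attention call sites below; which
    # of the three expert kwargs a given attention_type consumes is an assumption.
    attention = init_experimental_attention(
        embed_dim, att_heads, 'gma', att_groups, rope=rope,
        use_flash_attention=False, dropout=0.1,
        max_seq_len=seq_len, is_causal=True,
        num_experts=8, num_query_experts=None, num_query_groups=None,
    )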
rxnn/experimental/models.py CHANGED
@@ -8,7 +8,7 @@ from ..transformers.layers import ClassicTransformerLayer
 from ..transformers.models import ClassicTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..utils import get_model_size
-from .attention import init_moe_attention
+from .attention import init_experimental_attention
 
 
 class MoeAttentionTransformerConfig(TypedDict):
@@ -77,11 +77,11 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                                               use_flash_attention=use_flash_attention, dropout=att_dropout,
                                               max_seq_len=seq_len, is_causal=True)
         else:
-            att_init = lambda: init_moe_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
-                                                  use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                  max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                  num_query_experts=att_num_query_experts,
-                                                  num_query_groups=att_num_query_groups)
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)
 
         use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']
 
rxnn/rxt/models.py CHANGED
@@ -9,7 +9,7 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..utils import get_model_size
-
+from ..experimental.attention import init_experimental_attention
 
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
@@ -31,6 +31,9 @@ class RxTAlphaComponentConfig(TypedDict):
     moe_top_k: int
     self_att_type: str
     cross_att_type: str
+    att_num_experts: int
+    att_num_query_experts: int
+    att_num_query_groups: int
 
 
 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -58,14 +61,17 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             moe_top_k: int = 1,
             self_att_type: str = 'gqa',
             cross_att_type: str = 'mqa',
+            att_num_experts: int = None,
+            att_num_query_experts: int = None,
+            att_num_query_groups: int = None,
             **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa'], 'Self-attention type could be "mha", "gqa", "mqa"'
-        assert cross_att_type in ['mha', 'gqa', 'mqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa"'
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -73,6 +79,28 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
 
         ff_activation = get_activation_layer(ff_activation)
 
+        if self_att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=True)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)
+
+        if cross_att_type in ['mha', 'gqa', 'mqa']:
+            cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=True)
+        else:
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                                 num_query_experts=att_num_query_experts,
+                                                                 num_query_groups=att_num_query_groups)
+
         layers = nn.ModuleList([
             ReactiveTransformerLayer(
                 embed_dim,
@@ -84,13 +112,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 ff_activation=ff_activation,
                 ff_dropout=ff_dropout,
                 use_rms_norm=use_rms_norm,
-                self_attention=init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
-                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                              max_seq_len=seq_len, is_causal=is_causal),
-                memory_cross_attention=init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-                                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                      max_seq_len=seq_len, rope_only_for_query=True,
-                                                      is_causal=is_causal)
+                self_attention=att_init(),
+                memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
         self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
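For callers, the practical effect is three new constructor arguments on the RxTAlpha components. A minimal sketch of how they might be passed to a concrete component built on RxTAlphaComponentBase (values are hypothetical; which expert kwarg each experimental type consumes is an assumption based on the names):

    # Hypothetical keyword arguments; everything not shown keeps its default.
    experimental_att_kwargs = dict(
        self_att_type='gma',         # experimental path via init_experimental_attention
        cross_att_type='mqa',        # standard path via init_attention, still valid
        att_num_experts=8,           # assumed to drive the 'gma'/'dma' expert count
        att_num_query_experts=None,  # assumed relevant for 'dma' only
        att_num_query_groups=None,   # assumed relevant for 'sqa' only
    )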
rxnn-0.1.56.dist-info/METADATA → rxnn-0.1.57.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.56
+Version: 0.1.57
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.56.dist-info/RECORD → rxnn-0.1.57.dist-info/RECORD RENAMED
@@ -1,13 +1,13 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=KiPefkFuDWyWwVwGT-sqHbjkucf1VypkmpaAKFG3PFE,34539
-rxnn/experimental/models.py,sha256=iprFSQDPK75zebDJBJ1i-mnNS9jlGf9RAIk-S0E9D-Q,4689
+rxnn/experimental/attention.py,sha256=ivIqIc-15DWA_q-ITy2iaYmB7tffKVtiuqjdSH3mtS4,34548
+rxnn/experimental/models.py,sha256=_i9kvQsAYPyMQo2VfMUTmtBs-mE2w75j1X-OHx03IJk,4743
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
+rxnn/rxt/models.py,sha256=L5RvhORONmYSF_pVjP8HwiSeAypSNqfFi6Fogp2oJes,8543
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=gEWASLSuWR8UF8b2e-DYqkBZ1lBx0VsIm4kGf9eWSHM,11678
 rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.56.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.56.dist-info/METADATA,sha256=VdL35wYa0o-n0gZDPp0lYPkCIWFDpwnLjilZHavCFoc,16627
-rxnn-0.1.56.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.56.dist-info/RECORD,,
+rxnn-0.1.57.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.57.dist-info/METADATA,sha256=K9kcLSS3CUYwFg9N-KhPK5J4tMmsgKYFqF6VkH8689U,16627
+rxnn-0.1.57.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.57.dist-info/RECORD,,
rxnn-0.1.56.dist-info/LICENSE → rxnn-0.1.57.dist-info/LICENSE RENAMED (file without changes)
rxnn-0.1.56.dist-info/WHEEL → rxnn-0.1.57.dist-info/WHEEL RENAMED (file without changes)