rxnn 0.1.49__py3-none-any.whl → 0.1.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/models.py
@@ -83,6 +83,8 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
             num_query_experts=att_num_query_experts,
             num_query_groups=att_num_query_groups)
 
+        use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']
+
         self.model = ClassicTransformerDecoder(
             embed_dim,
             vocab_size,
@@ -99,6 +101,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                     ff_dropout=ff_dropout,
                     use_rms_norm=use_rms_norm,
                     self_attention=att_init(),
+                    use_moe_att=use_moe_att,
                 ) for _ in range(num_layers)
             ]),
             use_flash_attention=use_flash_attention,
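The new use_moe_att flag in MoeAttentionTransformer is derived from the attention variant name and then forwarded to every decoder layer, so only the MoE-based attention variants contribute a router loss. A minimal sketch of that gating, assuming nothing beyond what the diff shows; the helper function and the kwargs dict are illustrative, not part of the rxnn API:

# Sketch only: derive the MoE-attention flag from the attention type string,
# as the change above does, and forward it to each layer constructor.
MOE_ATT_TYPES = ('gma', 'dma', 'gma_s', 'dma_s')

def uses_moe_attention(att_type: str) -> bool:
    """True for the attention variants that carry an expert router."""
    return att_type in MOE_ATT_TYPES

use_moe_att = uses_moe_attention('gma')      # True
layer_kwargs = {'use_moe_att': use_moe_att}  # passed on to every decoder layer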
rxnn/transformers/layers.py
@@ -22,6 +22,7 @@ class ReactiveTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -54,6 +55,7 @@ class ReactiveTransformerLayer(nn.Module):
         self.norm3 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att
 
     def trainable_cross_attention_(self, is_trainable: bool):
         for param in self.memory_cross_attention.parameters():
@@ -62,12 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
         att_router_loss = None
-        if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
-            att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
-        elif self.attention.router_loss is not None:
-            att_router_loss = self.attention.router_loss()
-        elif self.memory_cross_attention.router_loss is not None:
-            att_router_loss = self.memory_cross_attention.router_loss()
+        if self.use_moe_att:
+            if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+                att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+            elif self.attention.router_loss is not None:
+                att_router_loss = self.attention.router_loss()
+            elif self.memory_cross_attention.router_loss is not None:
+                att_router_loss = self.memory_cross_attention.router_loss()
 
         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
@@ -123,6 +126,7 @@ class ClassicTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -151,10 +155,11 @@ class ClassicTransformerLayer(nn.Module):
         self.norm2 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att
 
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
-        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+        att_router_loss = self.attention.router_loss() if self.use_moe_att and self.attention.router_loss is not None else None
 
         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
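The layer-level change amounts to one extra guard: attention router losses are only collected when the layer was built with use_moe_att=True. A standalone sketch of that branching, with plain optional floats standing in for the tensor-valued router_loss() results; the function name and signature are illustrative, not the rxnn API:

from typing import Optional

# Sketch only: mirrors the guarded averaging in ReactiveTransformerLayer.moe_router_loss.
def attention_router_loss(
    use_moe_att: bool,
    self_att_loss: Optional[float],
    cross_att_loss: Optional[float],
) -> Optional[float]:
    if not use_moe_att:
        return None                                  # MoE attention disabled: nothing to report
    if self_att_loss is not None and cross_att_loss is not None:
        return (self_att_loss + cross_att_loss) / 2  # average both attention routers
    if self_att_loss is not None:
        return self_att_loss
    return cross_att_loss                            # may still be None

assert attention_router_loss(False, 0.1, 0.3) is None
assert attention_router_loss(True, 0.1, 0.3) == 0.2
assert attention_router_loss(True, None, 0.3) == 0.3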
rxnn/transformers/models.py
@@ -38,8 +38,8 @@ class ReactiveTransformerBase(nn.Module):
             self.layers[i].trainable_cross_attention_(is_trainable)
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
-            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
+            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
@@ -124,7 +124,7 @@ class ClassicTransformerBase(nn.Module):
         self.num_layers = len(layers) if layers else 0
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att]).mean()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
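In the base models the same flag widens the filter used when stacking per-layer router losses, so layers whose only experts sit in the attention module are no longer skipped. A small self-contained sketch of that aggregation; the _DemoLayer class is invented for illustration, and only the use_moe / use_moe_att / moe_router_loss names come from the diff:

import torch

# Sketch only: stack router losses from every layer that has any expert router,
# whether it lives in the feed-forward block (use_moe) or in attention (use_moe_att).
class _DemoLayer:
    def __init__(self, use_moe: bool, use_moe_att: bool, loss: float):
        self.use_moe = use_moe
        self.use_moe_att = use_moe_att
        self._loss = torch.tensor(loss)

    def moe_router_loss(self) -> torch.Tensor:
        return self._loss

layers = [_DemoLayer(True, False, 0.2), _DemoLayer(False, True, 0.4), _DemoLayer(False, False, 9.9)]
router_loss = torch.stack([
    layer.moe_router_loss() for layer in layers if layer.use_moe or layer.use_moe_att
]).mean()
print(router_loss)  # tensor(0.3000) - the dense third layer contributes nothing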
rxnn-0.1.49.dist-info/METADATA → rxnn-0.1.50.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.49
+Version: 0.1.50
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.49.dist-info/RECORD → rxnn-0.1.50.dist-info/RECORD
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
-rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
+rxnn/experimental/models.py,sha256=-BQn7gWlSHLpkAQdthPW5L9ZNzIBqSJS9tkm2N88jgw,4711
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
+rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
+rxnn/transformers/models.py,sha256=_w5C7xvjT4-BFeMfzi57BQ51_fgaYZ4UK0SqUDE5Ooo,7266
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.49.dist-info/METADATA,sha256=PijR2z5P5nuTlOaWn-ylU_Loluy-e2HRgpMEc4TCohk,16627
-rxnn-0.1.49.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.49.dist-info/RECORD,,
+rxnn-0.1.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.50.dist-info/METADATA,sha256=bIeDbrlcclSfD9oHf26i_sYepOTvTkpcwQMWpOm2jWc,16627
+rxnn-0.1.50.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.50.dist-info/RECORD,,