rxnn 0.2.57__py3-none-any.whl → 0.2.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -47,7 +47,10 @@ class StmMemoryAttention(nn.Module):
47
47
  else:
48
48
  return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
49
49
 
50
- def forward(self, x: torch.Tensor) -> torch.Tensor:
50
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
51
+ if attention_mask is not None:
52
+ print(attention_mask.size())
53
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
51
54
  new_stm = torch.zeros_like(self.stm.memory)
52
55
  for i in range(self.num_layers):
53
56
  layer_stm = self.stm(i)
@@ -56,7 +59,7 @@ class StmMemoryAttention(nn.Module):
56
59
  layer_stm = layer_stm.expand(x.size(0), -1, -1)
57
60
  encoded_layer_data = x[i]
58
61
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
59
- new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
62
+ new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
60
63
  if self.use_gated_residual:
61
64
  new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
62
65
  else:
rxnn/rxt/models.py CHANGED
@@ -103,13 +103,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
103
103
  if cross_att_type in ['mha', 'gqa', 'mqa']:
104
104
  cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
105
105
  use_flash_attention=use_flash_attention, dropout=att_dropout,
106
- max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
106
+ max_seq_len=seq_len, is_causal=False, rope_only_for_query=True)
107
107
  else:
108
108
  cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
109
109
  cross_att_groups or att_groups, rope=rope,
110
110
  use_flash_attention=use_flash_attention,
111
111
  dropout=att_dropout,
112
- max_seq_len=seq_len, is_causal=is_causal,
112
+ max_seq_len=seq_len, is_causal=False,
113
113
  num_experts=att_experts,
114
114
  num_query_experts=att_query_experts,
115
115
  num_query_groups=cross_att_query_groups or att_query_groups,
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
306
306
  def clone_reset_memory(self):
307
307
  self.model.stm.clone_detach_reset()
308
308
 
309
- def forward(self, x: torch.Tensor) -> torch.Tensor:
310
- return self.model(x)
309
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
310
+ return self.model(x, attention_mask=attention_mask)
311
311
 
312
312
  class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
313
313
  """RxT-Alpha (Reactive Transformer) encoder model"""
rxnn/training/models.py CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
204
204
  return self.decoder(x, attention_mask=attention_mask)
205
205
  else:
206
206
  _, ed = self.encoder(x, attention_mask=attention_mask)
207
- return self.memory_attention(ed)
207
+ return self.memory_attention(ed, attention_mask=attention_mask)
208
208
 
209
209
 
210
210
  class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
110
110
  residual = x
111
111
  if not self.use_post_norm:
112
112
  x = self.norm2(x)
113
- x = self.memory_cross_attention(x, stm, stm, mask=mask)
113
+
114
+ if mask is not None:
115
+ mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
116
+
117
+ x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
114
118
  x = residual + x
115
119
  if self.use_post_norm:
116
120
  x = self.norm2(x)
@@ -108,6 +108,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
108
108
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
109
109
  x = super().forward(x) # apply embeddings
110
110
  if attention_mask is not None:
111
+ print(attention_mask.size())
111
112
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
112
113
 
113
114
  hidden_states = []
@@ -213,6 +214,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
213
214
  if attention_mask is not None:
214
215
  mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
215
216
  elif attention_mask is not None:
217
+ print(attention_mask.size())
216
218
  mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
217
219
  else:
218
220
  mask = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.57
3
+ Version: 0.2.59
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -5,18 +5,18 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
5
5
  rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
6
6
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
7
7
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
8
+ rxnn/memory/attention.py,sha256=Mt9LC0DK5asjN4pTqWOLDAoGzAgg3FhQoUw6baDJ3NI,3309
9
9
  rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
10
10
  rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- rxnn/rxt/models.py,sha256=new_YXLe9vfIBPX-pmFRoV523d7yCjEgfTY06EaH3Ms,14605
12
+ rxnn/rxt/models.py,sha256=jd1UVBUWJzWdw7Rjcvo9k5BXCJriQ0khuVszqEyfD7M,14665
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
15
15
  rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
16
16
  rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
17
17
  rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
18
18
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
- rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
19
+ rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
20
20
  rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
21
21
  rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
22
22
  rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
@@ -26,14 +26,14 @@ rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
26
26
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
28
28
  rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
29
- rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
29
+ rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
30
30
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
31
- rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
31
+ rxnn/transformers/models.py,sha256=bHR5gy74aV20Sl0371vVCvb3Z2pcqYTdOSnrLSkAIiI,10802
32
32
  rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.57.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.57.dist-info/METADATA,sha256=xZ60cC1PUzse2BBDKPSOrFc_oVVLUdl6qKmZWUMaUa4,25997
38
- rxnn-0.2.57.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.57.dist-info/RECORD,,
36
+ rxnn-0.2.59.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.59.dist-info/METADATA,sha256=yj3LmcYzHTe7O5Y3LuDkCbjoTcY01WRQxaPQ9G4a8GA,25997
38
+ rxnn-0.2.59.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.59.dist-info/RECORD,,
File without changes
File without changes