rxnn 0.2.56__py3-none-any.whl → 0.2.58__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
rxnn/memory/attention.py CHANGED
@@ -47,7 +47,8 @@ class StmMemoryAttention(nn.Module):
         else:
             return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -56,7 +57,7 @@ class StmMemoryAttention(nn.Module):
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
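The new attention_mask parameter lets padded positions in the encoded sequence be ignored when short-term memory attends over it. A minimal standalone sketch of the reshaping step, assuming attention_mask is a [batch, seq_len] padding mask (1 = real token, 0 = padding); the two unsqueeze(1) calls make it broadcastable over attention heads and memory-slot queries:

    import torch

    batch_size, seq_len = 2, 8
    # second sequence has 3 padded positions at the end
    attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])
    mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
    print(mem_mask.shape)  # torch.Size([2, 1, 1, 8]) -> [batch, heads, query, key]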
rxnn/rxt/models.py CHANGED
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
 
 
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
rxnn/training/dataset.py CHANGED
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
         return self.get_tokenized_item(idx)
 
     def __len__(self) -> int:
-        return len(self.episodes)
+        return len(self.inputs if self.is_pre_tokenized else self.episodes)
 
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MrlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
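The old __len__ always counted raw episodes, which disagrees with __getitem__ once the dataset has been pre-tokenized and is served from inputs. A minimal standalone sketch of the corrected pattern (not the library class):

    class PreTokenizedView:
        def __init__(self, episodes: list, inputs: list = None):
            self.episodes = episodes
            self.is_pre_tokenized = inputs is not None
            self.inputs = inputs if inputs is not None else []

        def __len__(self) -> int:
            # report the collection that __getitem__ actually reads from
            return len(self.inputs if self.is_pre_tokenized else self.episodes)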
rxnn/training/models.py CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed)
+            return self.memory_attention(ed, attention_mask=attention_mask)
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
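In the actor, the memory-update branch now threads the same padding mask that guided the encoder into memory attention. A minimal sketch of that flow, mirroring the hunk above (function and argument names are illustrative):

    def memory_update_path(encoder, memory_attention, x, attention_mask):
        _, encoded_layers = encoder(x, attention_mask=attention_mask)  # per-layer encodings
        return memory_attention(encoded_layers, attention_mask=attention_mask)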
rxnn/training/mrl.py CHANGED
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
     use_memory_warmup: Optional[bool]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -606,7 +608,7 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
-            if self.debug_mode:
+            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
-            if self.debug_mode:
+            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
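debug_mode moves from a trainer constructor argument into MrlConfig, alongside a new debug_interval that throttles gradient logging to every N-th training step (defaults False and 10, per the hunks above). A hedged configuration sketch; other keys omitted:

    mrl_config = {
        # ... other MrlConfig fields ...
        'debug_mode': True,    # formerly MRLTrainer(..., debug_mode=True)
        'debug_interval': 25,  # _log_gradients() fires only when epoch_step['train'] % 25 == 0
    }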
rxnn/training/rl.py CHANGED
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
     use_distributed_advantage_norm: Optional[bool]
     clip_critic_values: Optional[bool]
     critic_value_clip: Optional[float]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
         self.critic_value_clip = config.get('critic_value_clip', 20.0)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
+        self.debug_step = 0
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
         advantages = advantages.unsqueeze(-1)
 
         if self.debug_mode:
-            print(
-                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
-            print(
-                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
-            print(
-                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                self.debug_step = 0
+                print(
+                    f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+                print(
+                    f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+                print(
+                    f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            else:
+                self.debug_step += 1
 
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
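PPO debug output gets the same treatment: config-driven flags plus a debug_step counter, so the logits/ratio/advantage stats print periodically instead of on every loss computation. A standalone sketch of the counter's behavior; note that with the default interval of 10 it prints on every 11th call, since the counter resets to zero after printing:

    debug_interval, debug_step = 10, 0
    for call in range(34):
        if debug_step != 0 and debug_step % debug_interval == 0:
            debug_step = 0
            print(f"stats printed on call {call}")  # calls 10, 21, 32
        else:
            debug_step += 1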
rxnn/transformers/layers.py CHANGED
@@ -110,7 +110,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm, mask=mask)
+
+        if mask is not None:
+            mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
+        else:
+            mem_mask = None
+
+        x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
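The reshape builds a per-query mask over the memory slots for cross-attention. A shape walkthrough under an assumption the diff does not state: that mask arrives as the usual [batch, 1, 1, seq_len] self-attention padding mask, so each token-query row is kept or dropped as a whole against all stm_size memory keys:

    import torch

    batch, seq_len, stm_size = 2, 8, 4
    mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)  # assumed input shape
    mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm_size)
    print(mem_mask.shape)  # torch.Size([2, 1, 8, 4]) -> [batch, heads, query, memory-slot key]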
{rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.56
+Version: 0.2.58
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/RECORD RENAMED
@@ -5,35 +5,35 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
+rxnn/memory/attention.py,sha256=N78kzcqXXfcq5v43LUsLheVUs4JrKPeQissl-_XKXdk,3241
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=new_YXLe9vfIBPX-pmFRoV523d7yCjEgfTY06EaH3Ms,14605
+rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
-rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
+rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
-rxnn/training/mrl.py,sha256=185rZsaFVaAt4mYksuflzKPiDuEaHpjsc3vPzxt9ax0,61862
+rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
+rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=iBEPC_gfydXtWkVORO3REMWvOtx60-0xB7MFzfghUK8,6825
+rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
+rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.56.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.56.dist-info/METADATA,sha256=qW9X-oP3LWHB0E6S0opHPzWjDmNBRy9DjOz5od4qutc,25997
-rxnn-0.2.56.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.56.dist-info/RECORD,,
+rxnn-0.2.58.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.58.dist-info/METADATA,sha256=_NKVaBMYEbJadkBcUvDX1UprMlUu92JqgnBYx7R1J1c,25997
+rxnn-0.2.58.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.58.dist-info/RECORD,,
{rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/LICENSE RENAMED
File without changes
{rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/WHEEL RENAMED
File without changes