rxnn 0.2.50__py3-none-any.whl → 0.2.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -12,6 +12,7 @@ class StmMemoryAttention(nn.Module):
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
+            use_tanh_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -24,6 +25,7 @@ class StmMemoryAttention(nn.Module):
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
         self.use_dynamic_gate = use_dynamic_gate
+        self.use_tanh_gate = use_tanh_gate
         if self.use_gated_residual:
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
@@ -37,10 +39,13 @@ class StmMemoryAttention(nn.Module):
         if self.use_dynamic_gate:
             mean_dim = -1 if self.per_slot_gate else [1, 2]
             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.sigmoid(gate_input)
+            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
         else:
-            layer_gate = torch.sigmoid(gate)
-        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+        else:
+            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
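The new `use_tanh_gate` option changes both the gate non-linearity and the residual formula: with a tanh gate the update becomes `(1 + g) * new + (1 - g) * old`, which at the default `init_gate=0.0` reduces to a plain sum of new and old memory, whereas the sigmoid gate yields a 50/50 convex blend. A minimal, self-contained sketch of just that arithmetic (toy tensors, not the actual `StmMemoryAttention` module):

```python
import torch

def gated_residual(new_stm: torch.Tensor, old_stm: torch.Tensor,
                   gate_param: torch.Tensor, use_tanh_gate: bool) -> torch.Tensor:
    """Static (non-dynamic) gated residual as in the diff above."""
    if use_tanh_gate:
        # tanh gate in [-1, 1]: at 0 the update is new_stm + old_stm;
        # positive values amplify the new state, negative values favor the old one
        gate = torch.tanh(gate_param)
        return (1 + gate) * new_stm + (1 - gate) * old_stm
    else:
        # sigmoid gate in [0, 1]: a convex interpolation between new and old state
        gate = torch.sigmoid(gate_param)
        return gate * new_stm + (1 - gate) * old_stm

old = torch.randn(1, 8, 16)    # toy STM layer: (batch, slots, dim)
new = torch.randn(1, 8, 16)
init_gate = torch.tensor(0.0)  # matches the init_gate=0.0 default

print(gated_residual(new, old, init_gate, use_tanh_gate=False).shape)  # sigmoid: 0.5*new + 0.5*old
print(gated_residual(new, old, init_gate, use_tanh_gate=True).shape)   # tanh: new + old
```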
rxnn/training/models.py CHANGED
@@ -208,17 +208,18 @@ class MrlActorModel(nn.Module):
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
-                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
             GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
             nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 1),
-            get_activation_layer(out_activation)
+            nn.Linear(embed_dim, 1)
         )
-        self.output_scale = output_scale
+        # Learnable scaling parameters
+        self.scale = nn.Parameter(torch.tensor(1.0))
+        self.shift = nn.Parameter(torch.tensor(0.0))
+
 
     def head_parameters(self) -> Iterator[nn.Parameter]:
         return self.value_head.parameters()
@@ -235,4 +236,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x) * self.output_scale
+        return self.value_head(x) * self.scale + self.shift
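`MrlCriticModel` now emits an unbounded value estimate: the fixed `out_activation`/`output_scale` pair is replaced by a raw `nn.Linear(embed_dim, 1)` followed by learnable `scale` and `shift` parameters. A rough stand-in for the new head, with rxnn's `GatedLinearUnit` swapped for a plain `Linear` + `SiLU` so the sketch stays self-contained:

```python
import torch
import torch.nn as nn

class TinyValueHead(nn.Module):
    """Minimal stand-in for the new MrlCriticModel head: unbounded linear output
    followed by a learnable affine rescaling (replacing the old fixed
    sigmoid/tanh plus output_scale)."""
    def __init__(self, embed_dim: int):
        super().__init__()
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),
        )
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # scale/shift are trained together with the head, so the value range
        # adapts to the reward scale instead of being fixed up front
        return self.value_head(pooled) * self.scale + self.shift

head = TinyValueHead(embed_dim=32)
values = head(torch.randn(4, 32))  # (batch, 1) unbounded value estimates
print(values.shape)
```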
rxnn/training/mrl.py CHANGED
@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict, get_gradient_norms
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -109,6 +109,7 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
+            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -139,6 +140,7 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
+        self.debug_mode = debug_mode
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -566,6 +568,14 @@ class MRLTrainer:
         else:
             return main_loss
 
+    def _log_gradients(self):
+        encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
+        decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
+        mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
+        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
+        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -596,6 +606,8 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -613,6 +625,8 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
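The new `debug_mode` flag makes `MRLTrainer` print per-component gradient norms right after gradient clipping (and, on the AMP path, after `scaler.unscale_`), i.e. before `optimizer.step()`. A toy sketch of where such a hook sits in a training step (hypothetical two-layer model, not the MRL actor); here the logged norm is simply the value `clip_grad_norm_` returns:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.SiLU(), nn.Linear(16, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
debug_mode = True

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
# clip_grad_norm_ returns the pre-clipping total norm, handy for a quick debug line
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
if debug_mode:
    print(f"Grad norm - total: {total_norm.item():.4f}")
optimizer.step()
optimizer.zero_grad()
```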
rxnn/training/rl.py CHANGED
@@ -36,7 +36,7 @@ class PPOConfig(TypedDict):
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None):
+    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -49,12 +49,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
-        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+        self.critic_value_clip = config.get('critic_value_clip', 20.0)
+        self.debug_mode = debug_mode
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+            ref_values = torch.clamp(ref_values, -self.critic_value_clip, self.critic_value_clip)
         return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -95,6 +97,14 @@ class PPOAlgorithm(RlAlgorithm):
 
         advantages = advantages.unsqueeze(-1)
 
+        if self.debug_mode:
+            print(
+                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+            print(
+                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+            print(
+                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
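Besides the debug statistics, the PPO changes raise the default `critic_value_clip` from 10.0 to 20.0 and clamp the reference values as well as the predictions before computing the critic loss. A small sketch of that clamping, assuming an MSE critic loss (the actual `critic_loss_fn` is configured on `RlAlgorithm` and is not shown in this diff):

```python
import torch
import torch.nn.functional as F

def clipped_critic_loss(values: torch.Tensor, ref_values: torch.Tensor,
                        clip: float = 20.0) -> torch.Tensor:
    """Both predicted and reference values are clamped to [-clip, clip]
    (previously only the predictions were), so a single outlier return
    cannot dominate the loss."""
    values = torch.clamp(values, -clip, clip)
    ref_values = torch.clamp(ref_values, -clip, clip)
    return F.mse_loss(values, ref_values)

preds = torch.tensor([0.5, 3.0, -1.2])
returns = torch.tensor([0.7, 150.0, -0.9])   # one outlier reference value
print(clipped_critic_loss(preds, returns))   # outlier is clamped to 20.0 before the MSE
```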
rxnn/training/utils.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from typing import TypedDict
 
 
@@ -142,3 +143,13 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
         'input_ids': combined_ids,
         'attention_mask': combined_mask
     }
+
+def get_gradient_norms(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    mean_norm = total_norm / len(list(model.parameters()))
+    return total_norm, mean_norm
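`get_gradient_norms` accumulates squared per-tensor gradient norms and takes the square root, so the returned total equals the 2-norm of all gradients concatenated into one vector; the reported "mean" is that total divided by the number of parameter tensors (including any without gradients), not a per-element mean. A quick hypothetical check of both points on a tiny module:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
model(torch.randn(2, 4)).sum().backward()

# Same accumulation as the helper above
total = sum(p.grad.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
# Equivalent: 2-norm of all gradients flattened into one vector
flat = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])

print(abs(total - flat.norm(2).item()) < 1e-5)  # True: the same quantity
print(total / len(list(model.parameters())))    # the reported "mean" norm (per tensor, not per element)
```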
rxnn/transformers/layers.py CHANGED
@@ -110,7 +110,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+        x = self.memory_cross_attention(x, stm, stm, mask=mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
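The one-line change in `ReactiveTransformerLayer` forwards the attention mask into memory cross-attention, so queries no longer attend unmasked over the STM slots. How rxnn builds and broadcasts that mask is not visible in this diff; the sketch below only illustrates passing a boolean mask into scaled dot-product cross-attention between a token sequence and a fixed set of memory slots (all shapes hypothetical):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes, not rxnn's internals: 4 query tokens attend over 8 STM slots.
x = torch.randn(1, 2, 4, 16)    # (batch, heads, query_len, head_dim)
stm = torch.randn(1, 2, 8, 16)  # (batch, heads, stm_slots, head_dim)

# Boolean mask broadcast to (batch, heads, query_len, stm_slots); True = may attend.
mask = torch.ones(1, 1, 4, 8, dtype=torch.bool)
mask[..., 6:] = False           # e.g. hide the last two memory slots

out = F.scaled_dot_product_attention(x, stm, stm, attn_mask=mask)
print(out.shape)                # (1, 2, 4, 16)
```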
rxnn-0.2.50.dist-info/METADATA → rxnn-0.2.52.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.50
+Version: 0.2.52
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.50.dist-info/RECORD → rxnn-0.2.52.dist-info/RECORD CHANGED
@@ -5,7 +5,7 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=kan6UNPTjLfO7zKNp92hGooldgWPi3li_2-_L5xiErs,2784
+rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,24 +16,24 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=L2emJM06u7B9f9T1dFsGXzXX-rsV77ND7L1pAM9Z_Ow,9051
-rxnn/training/mrl.py,sha256=IOi_xbQ47RPgv_2ucT9EkPeWLGBRlgPxKHFeQsYc3Pw,61074
+rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
+rxnn/training/mrl.py,sha256=185rZsaFVaAt4mYksuflzKPiDuEaHpjsc3vPzxt9ax0,61862
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
+rxnn/training/rl.py,sha256=iBEPC_gfydXtWkVORO3REMWvOtx60-0xB7MFzfghUK8,6825
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
+rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
+rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.50.dist-info/METADATA,sha256=MmlWkWUki9ErQnJ24yP2R9mDykQewDHDcyCQhzopZAw,25997
-rxnn-0.2.50.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.50.dist-info/RECORD,,
+rxnn-0.2.52.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.52.dist-info/METADATA,sha256=_0O3SB2pBSkic0UMXV4_nlYJjFw9bqUFroOs3A4FuW0,25997
+rxnn-0.2.52.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.52.dist-info/RECORD,,