rxnn 0.2.51.tar.gz → 0.2.52.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.51 → rxnn-0.2.52}/PKG-INFO +1 -1
  2. {rxnn-0.2.51 → rxnn-0.2.52}/pyproject.toml +1 -1
  3. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/memory/attention.py +8 -3
  4. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/mrl.py +15 -1
  5. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/rl.py +11 -2
  6. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/utils.py +11 -0
  7. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/layers.py +1 -1
  8. {rxnn-0.2.51 → rxnn-0.2.52}/LICENSE +0 -0
  9. {rxnn-0.2.51 → rxnn-0.2.52}/README.md +0 -0
  10. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/experimental/models.py +0 -0
  15. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/experimental/moe.py +0 -0
  16. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/memory/__init__.py +0 -0
  17. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/memory/stm.py +0 -0
  19. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/rxt/__init__.py +0 -0
  20. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/rxt/models.py +0 -0
  21. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/__init__.py +0 -0
  22. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/base.py +0 -0
  23. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/bml.py +0 -0
  24. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/callbacks.py +0 -0
  25. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/dataset.py +0 -0
  26. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/ddp.py +0 -0
  27. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/models.py +0 -0
  28. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/reward.py +0 -0
  29. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/scheduler.py +0 -0
  30. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/tokenizer.py +0 -0
  31. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/utils.py +0 -0
{rxnn-0.2.51 → rxnn-0.2.52}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.51
+Version: 0.2.52
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

{rxnn-0.2.51 → rxnn-0.2.52}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.51"
+version = "0.2.52"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"

{rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/memory/attention.py

@@ -12,6 +12,7 @@ class StmMemoryAttention(nn.Module):
         per_slot_gate: bool = False,
         init_gate: float = 0.0,
         use_dynamic_gate: bool = False,
+        use_tanh_gate: bool = False,
         *args,
         **kwargs
     ):
@@ -24,6 +25,7 @@ class StmMemoryAttention(nn.Module):
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
         self.use_dynamic_gate = use_dynamic_gate
+        self.use_tanh_gate = use_tanh_gate
         if self.use_gated_residual:
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
@@ -37,10 +39,13 @@ class StmMemoryAttention(nn.Module):
         if self.use_dynamic_gate:
             mean_dim = -1 if self.per_slot_gate else [1, 2]
             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.sigmoid(gate_input)
+            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
         else:
-            layer_gate = torch.sigmoid(gate)
-        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+        else:
+            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
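
The attention.py change above adds an optional tanh gate next to the existing sigmoid-gated residual. A minimal standalone sketch of the two variants, with illustrative tensor and argument names (not the package's own API):

import torch

def gated_residual(new_stm: torch.Tensor, old_stm: torch.Tensor,
                   gate_param: torch.Tensor, use_tanh_gate: bool = False) -> torch.Tensor:
    # Sigmoid gate: convex interpolation, the result stays between the old and new state.
    # Tanh gate: (1 + g) and (1 - g) weights, so the update can amplify the new state
    # (g > 0) or favor the old one (g < 0); at g = 0 it reduces to new + old.
    if use_tanh_gate:
        gate = torch.tanh(gate_param)
        return (1 + gate) * new_stm + (1 - gate) * old_stm
    gate = torch.sigmoid(gate_param)
    return gate * new_stm + (1 - gate) * old_stm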

{rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/mrl.py

@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict, get_gradient_norms
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -109,6 +109,7 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
+            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -139,6 +140,7 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
+        self.debug_mode = debug_mode
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -566,6 +568,14 @@ class MRLTrainer:
         else:
             return main_loss
 
+    def _log_gradients(self):
+        encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
+        decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
+        mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
+        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
+        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -596,6 +606,8 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -613,6 +625,8 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer

{rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/rl.py

@@ -36,7 +36,7 @@ class PPOConfig(TypedDict):
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None):
+    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -49,7 +49,8 @@ class PPOAlgorithm(RlAlgorithm):
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
-        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+        self.critic_value_clip = config.get('critic_value_clip', 20.0)
+        self.debug_mode = debug_mode
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -96,6 +97,14 @@ class PPOAlgorithm(RlAlgorithm):
 
         advantages = advantages.unsqueeze(-1)
 
+        if self.debug_mode:
+            print(
+                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+            print(
+                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+            print(
+                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
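
The debug prints above are emitted just before the clipped surrogate objective is formed. For reference, a sketch of that standard PPO computation using the same variable names; the function wrapper, default clip_eps, and final reduction are illustrative, not copied from the package:

import torch

def clipped_surrogate_loss(ratio: torch.Tensor, advantages: torch.Tensor,
                           clip_eps: float = 0.2) -> torch.Tensor:
    # Take the pessimistic minimum of the unclipped and clipped surrogates,
    # negated so the policy loss can be minimized with gradient descent.
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(surr1, surr2).mean()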

{rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/training/utils.py

@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from typing import TypedDict
 
 
@@ -142,3 +143,13 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
         'input_ids': combined_ids,
         'attention_mask': combined_mask
     }
+
+def get_gradient_norms(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    mean_norm = total_norm / len(list(model.parameters()))
+    return total_norm, mean_norm
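
A usage sketch for the new helper, assuming rxnn 0.2.52 is installed: call it after backward() and before the optimizer step, which is when the debug_mode branches in mrl.py invoke it. The toy module stands in for the encoder, decoder, or memory attention; note that the returned mean is the total L2 norm divided by the number of parameter tensors.

import torch
import torch.nn as nn
from rxnn.training.utils import get_gradient_norms

model = nn.Linear(16, 4)                  # toy stand-in for an actor sub-module
loss = model(torch.randn(8, 16)).sum()
loss.backward()                           # populates .grad on the parameters
total, mean = get_gradient_norms(model)
print(f"grad norm - total: {total:.4f}, mean: {mean:.4f}")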

{rxnn-0.2.51 → rxnn-0.2.52}/src/rxnn/transformers/layers.py

@@ -110,7 +110,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+        x = self.memory_cross_attention(x, stm, stm, mask=mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
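
The layers.py fix forwards the layer's attention mask into the memory cross-attention call, so the mask is now applied there instead of being silently dropped. A rough sketch of how such a mask is typically applied in scaled dot-product cross-attention; this is illustrative only and not the package's implementation:

import torch
import torch.nn.functional as F
from typing import Optional

def masked_cross_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # q: (batch, q_len, dim); k, v: (batch, kv_len, dim)
    # mask: broadcastable to (batch, q_len, kv_len); 0 marks positions to ignore
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return F.softmax(scores, dim=-1) @ v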