rxnn 0.2.56__tar.gz → 0.2.57__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.56 → rxnn-0.2.57}/PKG-INFO +1 -1
  2. {rxnn-0.2.56 → rxnn-0.2.57}/pyproject.toml +1 -1
  3. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/dataset.py +1 -1
  4. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/mrl.py +6 -4
  5. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/rl.py +16 -8
  6. {rxnn-0.2.56 → rxnn-0.2.57}/LICENSE +0 -0
  7. {rxnn-0.2.56 → rxnn-0.2.57}/README.md +0 -0
  8. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/rxt/models.py +0 -0
  20. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/models.py +0 -0
  26. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/utils.py +0 -0
{rxnn-0.2.56 → rxnn-0.2.57}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.56
+Version: 0.2.57
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.56 → rxnn-0.2.57}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.56"
+version = "0.2.57"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/dataset.py
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
         return self.get_tokenized_item(idx)
 
     def __len__(self) -> int:
-        return len(self.episodes)
+        return len(self.inputs if self.is_pre_tokenized else self.episodes)
 
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
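
The dataset.py change makes __len__ report the number of pre-tokenized items when the dataset has been pre-tokenized, instead of always counting raw episodes. A minimal standalone sketch of that pattern (the attribute names inputs, episodes and is_pre_tokenized are taken from the diff; the rest is illustrative and is not the actual MrlCurriculumDataset):

    class LengthSketch:
        def __init__(self, episodes: list, inputs: list | None = None):
            self.episodes = episodes          # raw episode records
            self.inputs = inputs if inputs is not None else []  # pre-tokenized items
            self.is_pre_tokenized = inputs is not None

        def __len__(self) -> int:
            # After pre-tokenization the items live in `inputs`, so report that
            # length; otherwise fall back to the raw episode count.
            return len(self.inputs if self.is_pre_tokenized else self.episodes)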
{rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/mrl.py
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
     use_memory_warmup: Optional[bool]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -606,7 +608,7 @@ class MRLTrainer:
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.optimizer.step()
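
With these mrl.py changes, debug logging is configured through the MrlConfig dict rather than an MRLTrainer constructor argument, and gradient logging fires only on every debug_interval-th training step. An illustrative config fragment (only the two new keys are spelled out; the other fields are the existing MrlConfig entries visible in the diff context):

    mrl_config = {
        # ... other MrlConfig fields (update_epochs, freeze_embeddings,
        #     embedding_lr, use_memory_warmup, ...) ...
        'debug_mode': True,      # turn gradient logging on
        'debug_interval': 10,    # log only when epoch_step['train'] % 10 == 0
    }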
{rxnn-0.2.56 → rxnn-0.2.57}/src/rxnn/training/rl.py
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
     use_distributed_advantage_norm: Optional[bool]
     clip_critic_values: Optional[bool]
     critic_value_clip: Optional[float]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
         self.critic_value_clip = config.get('critic_value_clip', 20.0)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
+        self.debug_step = 0
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
             advantages = advantages.unsqueeze(-1)
 
         if self.debug_mode:
-            print(
-                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
-            print(
-                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
-            print(
-                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                self.debug_step = 0
+                print(
+                    f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+                print(
+                    f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+                print(
+                    f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            else:
+                self.debug_step += 1
 
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
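
In rl.py the per-call debug prints are now throttled by a simple counter: debug_step is incremented on each call and the logits/ratio/advantage statistics are printed (and the counter reset) only once it reaches a multiple of debug_interval. A standalone sketch of just that counter logic (names follow the diff; the print line stands in for the actual statistics output):

    debug_interval = 3
    debug_step = 0
    for call in range(10):
        if debug_step != 0 and debug_step % debug_interval == 0:
            debug_step = 0
            print(f"call {call}: print logits/ratio/advantage stats")
        else:
            debug_step += 1
    # With debug_interval=3 the stats are printed on calls 3, 7, ... (every
    # fourth call), because the counter is reset to 0 after each print.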