rxnn 0.2.56__py3-none-any.whl → 0.2.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/dataset.py +1 -1
- rxnn/training/mrl.py +6 -4
- rxnn/training/rl.py +16 -8
- {rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/METADATA +1 -1
- {rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/RECORD +7 -7
- {rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/LICENSE +0 -0
- {rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/WHEEL +0 -0
rxnn/training/dataset.py
CHANGED
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
         return self.get_tokenized_item(idx)
 
     def __len__(self) -> int:
-        return len(self.episodes)
+        return len(self.inputs if self.is_pre_tokenized else self.episodes)
 
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
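The __len__ fix above makes the reported dataset length follow the pre-tokenized inputs once they exist, instead of always counting raw episodes. Below is a minimal standalone sketch of that pattern; it is not the actual MrlCurriculumDataset class, and only the is_pre_tokenized / inputs / episodes fields are taken from the diff, everything else is illustrative.

# Standalone sketch of the fixed __len__ behavior (hypothetical class, not rxnn's own):
class CurriculumLenSketch:
    def __init__(self, episodes, inputs=None):
        self.episodes = episodes          # raw, untokenized interaction episodes
        self.inputs = inputs              # pre-tokenized items, if prepared up front
        self.is_pre_tokenized = inputs is not None

    def __len__(self) -> int:
        # After pre-tokenization, length must track the tokenized inputs;
        # otherwise it falls back to the raw episodes (the 0.2.57 fix).
        return len(self.inputs if self.is_pre_tokenized else self.episodes)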
rxnn/training/mrl.py
CHANGED
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
     use_memory_warmup: Optional[bool]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -606,7 +608,7 @@ class MRLTrainer:
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.optimizer.step()
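With 0.2.57, MRLTrainer no longer takes a debug_mode constructor argument: debug logging is driven by two new MrlConfig fields, and gradient logging fires only every debug_interval training steps. A hedged usage sketch follows; only debug_mode, debug_interval, freeze_embeddings and use_memory_warmup are taken from this diff, while the remaining MrlConfig keys and the MRLTrainer constructor arguments are not shown here and are left elided.

# Hypothetical config snippet for rxnn/training/mrl.py's MRLTrainer (values illustrative):
mrl_config = {
    'freeze_embeddings': False,   # existing field visible in the diff context
    'use_memory_warmup': False,   # existing field visible in the diff context
    'debug_mode': True,           # new in 0.2.57, read via config.get('debug_mode', False)
    'debug_interval': 10,         # new in 0.2.57, gradients logged every N train steps
}
# trainer = MRLTrainer(..., config=mrl_config)  # other constructor args elided, see mrl.py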
rxnn/training/rl.py
CHANGED
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
     use_distributed_advantage_norm: Optional[bool]
     clip_critic_values: Optional[bool]
     critic_value_clip: Optional[float]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
         self.critic_value_clip = config.get('critic_value_clip', 20.0)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
+        self.debug_step = 0
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
         advantages = advantages.unsqueeze(-1)
 
         if self.debug_mode:
-
-
-
-
-
-
+            if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                self.debug_step = 0
+                print(
+                    f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+                print(
+                    f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+                print(
+                    f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            else:
+                self.debug_step += 1
 
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
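The rl.py changes mirror this: the dangling debug_mode reference is replaced by PPOConfig fields, the missing closing parenthesis in __init__ is fixed, and an internal debug_step counter throttles the logits/ratio/advantage printouts to roughly every debug_interval policy-loss calls. A minimal usage sketch follows, under the assumption that unspecified PPOConfig keys keep the defaults shown above; the import path simply follows the file location in this wheel.

from rxnn.training.rl import PPOAlgorithm

# Hypothetical PPOConfig dict enabling the new debug fields (values illustrative):
ppo_config = {
    'clip_critic_values': True,   # existing field, default visible in the diff
    'critic_value_clip': 20.0,    # existing field, default visible in the diff
    'debug_mode': True,           # new in 0.2.57
    'debug_interval': 10,         # new in 0.2.57, stats printed every N calls
}
algorithm = PPOAlgorithm(ppo_config)  # signature fixed in 0.2.57: __init__(self, config: Optional[PPOConfig] = None)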
{rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/RECORD
CHANGED
@@ -14,12 +14,12 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.57.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.57.dist-info/METADATA,sha256=xZ60cC1PUzse2BBDKPSOrFc_oVVLUdl6qKmZWUMaUa4,25997
+rxnn-0.2.57.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.57.dist-info/RECORD,,
{rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/LICENSE
File without changes
{rxnn-0.2.56.dist-info → rxnn-0.2.57.dist-info}/WHEEL
File without changes