rxnn 0.2.50__py3-none-any.whl → 0.2.52__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- rxnn/memory/attention.py +8 -3
- rxnn/training/models.py +7 -6
- rxnn/training/mrl.py +15 -1
- rxnn/training/rl.py +12 -2
- rxnn/training/utils.py +11 -0
- rxnn/transformers/layers.py +1 -1
- {rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/METADATA +1 -1
- {rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/RECORD +10 -10
- {rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/LICENSE +0 -0
- {rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -12,6 +12,7 @@ class StmMemoryAttention(nn.Module):
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
+            use_tanh_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -24,6 +25,7 @@ class StmMemoryAttention(nn.Module):
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
         self.use_dynamic_gate = use_dynamic_gate
+        self.use_tanh_gate = use_tanh_gate
         if self.use_gated_residual:
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
@@ -37,10 +39,13 @@ class StmMemoryAttention(nn.Module):
         if self.use_dynamic_gate:
             mean_dim = -1 if self.per_slot_gate else [1, 2]
             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.sigmoid(gate_input)
+            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
         else:
-            layer_gate = torch.sigmoid(gate)
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+        else:
+            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
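What changed: StmMemoryAttention gains a use_tanh_gate flag. A sigmoid gate blends old and new memory convexly (g * new + (1 - g) * old); the tanh variant produces g in (-1, 1) and uses (1 + g) * new + (1 - g) * old, whose coefficients always sum to 2. A minimal sketch of just the gating arithmetic, outside the actual module (tensor shapes are illustrative):

import torch

def gated_residual(new_stm: torch.Tensor, old_stm: torch.Tensor,
                   gate_input: torch.Tensor, use_tanh_gate: bool) -> torch.Tensor:
    if use_tanh_gate:
        gate = torch.tanh(gate_input)                        # gate in (-1, 1)
        return (1 + gate) * new_stm + (1 - gate) * old_stm   # weights sum to 2
    gate = torch.sigmoid(gate_input)                         # gate in (0, 1)
    return gate * new_stm + (1 - gate) * old_stm             # convex blend

new, old = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
zero = torch.zeros(1)
# At a zero gate, the tanh form passes new + old through un-normalized,
# while the sigmoid form averages them 0.5/0.5.
assert torch.allclose(gated_residual(new, old, zero, True), new + old)
assert torch.allclose(gated_residual(new, old, zero, False), 0.5 * (new + old))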
rxnn/training/models.py
CHANGED
@@ -208,17 +208,18 @@ class MrlActorModel(nn.Module):
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
-                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
             GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
             nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 1)
-            get_activation_layer(out_activation)
+            nn.Linear(embed_dim, 1)
         )
-
+        # Learnable scaling parameters
+        self.scale = nn.Parameter(torch.tensor(1.0))
+        self.shift = nn.Parameter(torch.tensor(0.0))
+
 
     def head_parameters(self) -> Iterator[nn.Parameter]:
         return self.value_head.parameters()
@@ -235,4 +236,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x) * self.
+        return self.value_head(x) * self.scale + self.shift
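What changed: MrlCriticModel drops the configurable output activation (sigmoid/tanh/linear) and its static output_scale in favor of a raw linear head rescaled by two learnable scalars, so the value range is fitted during training rather than fixed up front. A minimal sketch of the idea, with a plain MLP standing in for the repo's GatedLinearUnit stack:

import torch
import torch.nn as nn

class TinyCritic(nn.Module):
    """Sketch of the 0.2.52 value head; a plain MLP stands in for the
    repo's GatedLinearUnit + LayerNorm stack."""
    def __init__(self, embed_dim: int):
        super().__init__()
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),
        )
        # Learnable affine rescaling replaces the fixed activation + scale.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.value_head(x) * self.scale + self.shift

critic = TinyCritic(embed_dim=32)
values = critic(torch.randn(4, 32))  # unbounded value estimates, shape (4, 1)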
rxnn/training/mrl.py
CHANGED
@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict, get_gradient_norms
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -109,6 +109,7 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
+            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -139,6 +140,7 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
+        self.debug_mode = debug_mode
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -566,6 +568,14 @@ class MRLTrainer:
         else:
             return main_loss
 
+    def _log_gradients(self):
+        encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
+        decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
+        mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
+        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
+        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -596,6 +606,8 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -613,6 +625,8 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
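What changed: MRLTrainer accepts a debug_mode flag and, when it is set, prints per-component gradient norms right after gradient clipping in both the AMP and non-AMP branches of update_actor. A generic sketch of where that hook sits in a training step; model, loss, and optimizer are stand-ins, not the trainer's actual attributes:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

def debug_step(model: nn.Module, loss: torch.Tensor,
               optimizer: torch.optim.Optimizer, debug_mode: bool = False) -> None:
    """Generic sketch: gradients are logged after clipping, before the step."""
    optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=False)
    if debug_mode:
        # Total L2 norm over all parameter gradients, post-clipping.
        total = sum(p.grad.norm(2).item() ** 2
                    for p in model.parameters() if p.grad is not None) ** 0.5
        print(f"grad norm after clipping - total: {total:.4f}")
    optimizer.step()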
rxnn/training/rl.py
CHANGED
@@ -36,7 +36,7 @@ class PPOConfig(TypedDict):
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None):
+    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -49,12 +49,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
-        self.critic_value_clip = config.get('critic_value_clip',
+        self.critic_value_clip = config.get('critic_value_clip', 20.0)
+        self.debug_mode = debug_mode
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+            ref_values = torch.clamp(ref_values, -self.critic_value_clip, self.critic_value_clip)
         return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -95,6 +97,14 @@ class PPOAlgorithm(RlAlgorithm):
 
         advantages = advantages.unsqueeze(-1)
 
+        if self.debug_mode:
+            print(
+                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+            print(
+                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+            print(
+                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
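What changed: PPOAlgorithm gains a debug_mode that prints logits/ratio/advantage statistics before the surrogate loss, and critic_loss now clamps ref_values to the same critic_value_clip range as values, so a heavy-tailed return can no longer dominate the regression target. A sketch of the symmetric clamp, assuming MSE in place of the trainer's configured critic_loss_fn:

import torch
import torch.nn.functional as F

def clipped_critic_loss(values: torch.Tensor, ref_values: torch.Tensor,
                        clip: float = 20.0) -> torch.Tensor:
    # 0.2.52 clamps both sides; 0.2.50 clamped only the predictions.
    values = torch.clamp(values, -clip, clip)
    ref_values = torch.clamp(ref_values, -clip, clip)
    return F.mse_loss(values, ref_values)  # MSE is an assumption here

loss = clipped_critic_loss(torch.randn(8, 1) * 5, torch.randn(8, 1) * 50)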
rxnn/training/utils.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from typing import TypedDict
 
 
@@ -142,3 +143,13 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
         'input_ids': combined_ids,
         'attention_mask': combined_mask
     }
+
+def get_gradient_norms(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    mean_norm = total_norm / len(list(model.parameters()))
+    return total_norm, mean_norm
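A quick usage sketch for the new helper, assuming get_gradient_norms as defined above is in scope. Note that mean_norm divides the total L2 norm by the number of parameter tensors, including tensors that received no gradient, not by the number of scalar weights:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 1))
model(torch.randn(4, 16)).sum().backward()

total, mean = get_gradient_norms(model)  # helper defined above
# mean divides by the count of parameter tensors (4 here: 2 weights + 2 biases).
print(f"total: {total:.4f}, mean per tensor: {mean:.4f}")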
rxnn/transformers/layers.py
CHANGED
@@ -110,7 +110,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+        x = self.memory_cross_attention(x, stm, stm, mask=mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
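What changed: ReactiveTransformerLayer now forwards its mask argument to memory cross-attention instead of silently dropping it, matching the self-attention path. A standalone sketch of the effect using torch.nn.functional.scaled_dot_product_attention; shapes and mask semantics here are illustrative, not the repo's exact interface:

import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 16)    # (batch, seq, dim) queries from the sequence
stm = torch.randn(2, 4, 16)  # (batch, slots, dim) short-term memory

mask = torch.ones(2, 1, 6, 4, dtype=torch.bool)  # True = may attend
mask[:, :, :, -1] = False                        # e.g. hide the last slot

# Queries attend over memory slots; the mask now reaches this call.
out = F.scaled_dot_product_attention(
    x.unsqueeze(1), stm.unsqueeze(1), stm.unsqueeze(1), attn_mask=mask)
print(out.shape)  # torch.Size([2, 1, 6, 16])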
{rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/RECORD
CHANGED
@@ -5,7 +5,7 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,24 +16,24 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
+rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
+rxnn/training/mrl.py,sha256=185rZsaFVaAt4mYksuflzKPiDuEaHpjsc3vPzxt9ax0,61862
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=iBEPC_gfydXtWkVORO3REMWvOtx60-0xB7MFzfghUK8,6825
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.52.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.52.dist-info/METADATA,sha256=_0O3SB2pBSkic0UMXV4_nlYJjFw9bqUFroOs3A4FuW0,25997
+rxnn-0.2.52.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.52.dist-info/RECORD,,
{rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/LICENSE
File without changes
{rxnn-0.2.50.dist-info → rxnn-0.2.52.dist-info}/WHEEL
File without changes