rxnn 0.2.56__py3-none-any.whl → 0.2.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +3 -2
- rxnn/rxt/models.py +2 -2
- rxnn/training/dataset.py +1 -1
- rxnn/training/models.py +1 -1
- rxnn/training/mrl.py +6 -4
- rxnn/training/rl.py +16 -8
- rxnn/transformers/layers.py +5 -1
- {rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/METADATA +1 -1
- {rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/RECORD +11 -11
- {rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/LICENSE +0 -0
- {rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -47,7 +47,8 @@ class StmMemoryAttention(nn.Module):
         else:
             return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask else None
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -56,7 +57,7 @@ class StmMemoryAttention(nn.Module):
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
             else:
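The new optional attention_mask lets padding positions be excluded when the encoded sequence is attended into short-term memory. A minimal standalone sketch of the mask reshaping (tensor sizes are illustrative assumptions, not values from the package):

import torch

batch_size, seq_len = 2, 6  # illustrative sizes

# Padding mask as produced by a tokenizer: 1 for real tokens, 0 for padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

# Same reshaping as the new forward(): [batch, seq_len] -> [batch, 1, 1, seq_len],
# which broadcasts against the usual multi-head attention score shape
# [batch, heads, stm_slots, seq_len].
mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
print(mem_mask.shape)  # torch.Size([2, 1, 1, 6])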
rxnn/rxt/models.py
CHANGED
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
rxnn/training/dataset.py
CHANGED
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
         return self.get_tokenized_item(idx)

     def __len__(self) -> int:
-        return len(self.episodes)
+        return len(self.inputs if self.is_pre_tokenized else self.episodes)

     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
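For pre-tokenized curriculum datasets, the reported length now follows the container that item access actually reads. A toy standalone illustration of the same conditional-__len__ idiom (not the package's class):

class ToyCurriculumDataset:
    def __init__(self, episodes):
        self.episodes = episodes   # raw interaction episodes
        self.inputs = []           # filled once items are pre-tokenized
        self.is_pre_tokenized = False

    def pre_tokenize(self):
        # Stand-in for real tokenization; afterwards items are served from self.inputs.
        self.inputs = [f"tokenized:{e}" for e in self.episodes]
        self.is_pre_tokenized = True

    def __len__(self) -> int:
        # Report the length of whichever container is actually used for item access.
        return len(self.inputs if self.is_pre_tokenized else self.episodes)

ds = ToyCurriculumDataset(["ep1", "ep2", "ep3"])
print(len(ds))   # 3, counted from episodes
ds.pre_tokenize()
print(len(ds))   # 3, now counted from inputs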
rxnn/training/models.py
CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed)
+            return self.memory_attention(ed, attention_mask=attention_mask)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
rxnn/training/mrl.py
CHANGED
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
     use_memory_warmup: Optional[bool]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]


 class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -606,7 +608,7 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
-            if self.debug_mode:
+            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
-            if self.debug_mode:
+            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
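Debug logging for the MRL trainer moves from an MRLTrainer constructor argument into MrlConfig. A hedged sketch of just the keys touched here (a real config still needs the library's remaining fields):

# Only the keys affected by this release are shown; the fallbacks below mirror the
# config.get(...) defaults in the diff (debug_mode=False, debug_interval=10).
mrl_config = {
    'use_memory_warmup': False,
    'debug_mode': True,       # previously a debug_mode=... argument on MRLTrainer.__init__
    'debug_interval': 10,     # gradient stats are logged on every N-th training step
}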
rxnn/training/rl.py
CHANGED
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
     use_distributed_advantage_norm: Optional[bool]
     clip_critic_values: Optional[bool]
     critic_value_clip: Optional[float]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]


 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()

         if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
         self.critic_value_clip = config.get('critic_value_clip', 20.0)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
+        self.debug_step = 0

     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
         advantages = advantages.unsqueeze(-1)

         if self.debug_mode:
-            print(
-                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
-            print(
-                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
-            print(
-                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                self.debug_step = 0
+                print(
+                    f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+                print(
+                    f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+                print(
+                    f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            else:
+                self.debug_step += 1

         # c) Clipped surrogate loss
         surr1 = ratio * advantages
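PPOAlgorithm now gates its logits/ratio/advantage printouts with an internal step counter instead of printing on every update. A standalone sketch of that counter pattern with toy values (not the library's tensors):

debug_interval = 10   # mirrors the config.get('debug_interval', 10) default
debug_step = 0

for step in range(25):
    # Same gating as the new code: print once the counter reaches the interval,
    # reset it, and otherwise just keep counting.
    if debug_step != 0 and debug_step % debug_interval == 0:
        debug_step = 0
        print(f"step {step}: would print logits/ratio/advantage stats here")
    else:
        debug_step += 1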
rxnn/transformers/layers.py
CHANGED
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+
+        if mask is not None:
+            mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
+
+        x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
{rxnn-0.2.56.dist-info → rxnn-0.2.58.dist-info}/RECORD
CHANGED
@@ -5,35 +5,35 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=N78kzcqXXfcq5v43LUsLheVUs4JrKPeQissl-_XKXdk,3241
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
+rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
+rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.58.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.58.dist-info/METADATA,sha256=_NKVaBMYEbJadkBcUvDX1UprMlUu92JqgnBYx7R1J1c,25997
+rxnn-0.2.58.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.58.dist-info/RECORD,,