rxnn 0.2.60__tar.gz → 0.2.62__tar.gz
This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {rxnn-0.2.60 → rxnn-0.2.62}/PKG-INFO +1 -1
- {rxnn-0.2.60 → rxnn-0.2.62}/pyproject.toml +1 -1
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/memory/attention.py +14 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/rxt/models.py +3 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/mrl.py +48 -17
- {rxnn-0.2.60 → rxnn-0.2.62}/LICENSE +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/README.md +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/utils.py +0 -0
{rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/memory/attention.py

@@ -13,6 +13,8 @@ class StmMemoryAttention(nn.Module):
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
             use_tanh_gate: bool = False,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
             *args,
             **kwargs
     ):
@@ -30,6 +32,10 @@ class StmMemoryAttention(nn.Module):
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
 
+        self.debug_mode = debug_mode
+        self.debug_interval = debug_interval
+        self.debug_step = 0
+
     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
             if self.attention_layers[i].rope is not None:
@@ -58,6 +64,14 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
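The new `debug_mode`/`debug_interval` options throttle a diagnostic print of the normalized STM statistics so it fires only once per `debug_interval` training-mode calls. Below is a minimal standalone sketch of that throttling pattern; the `DebugProbe` wrapper is hypothetical and only the counter logic mirrors the diff.

```python
import torch

class DebugProbe:
    """Hypothetical helper mirroring the debug_step/debug_interval throttling added above."""

    def __init__(self, interval: int = 10, enabled: bool = True):
        self.interval = interval
        self.enabled = enabled
        self.step = 0  # calls since the last printed report

    def maybe_report(self, name: str, tensor: torch.Tensor) -> None:
        if not self.enabled:
            return
        if self.step != 0 and self.step % self.interval == 0:
            self.step = 0  # reset the counter, matching the diff's behaviour
            print(f"{name} stats - mean: {tensor.mean().item():.4f}, std: {tensor.std().item():.4f}")
        else:
            self.step += 1

# usage sketch: prints roughly every `interval`-th call
probe = DebugProbe(interval=10)
for _ in range(25):
    probe.maybe_report("Normalized STM", torch.randn(4, 16, 64))
```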
{rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/rxt/models.py

@@ -254,6 +254,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
             use_dynamic_residual_gate: bool = False,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -284,6 +286,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             stm, attention_layers, memory_norm_layers,
             use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
             init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
{rxnn-0.2.60 → rxnn-0.2.62}/src/rxnn/training/mrl.py

@@ -22,6 +22,9 @@ class MrlConfig(TypedDict):
     memory_lr: Optional[float]
     critic_lr: float
     critic_encoder_lr: Optional[float]
+    encoder_lr: Optional[float]
+    encoder_memory_lr: Optional[float]
+    memory_attn_lr: Optional[float]
     max_seq_len: int
     critic_max_len: int
     weight_decay: Optional[float]
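The three new optional fields let an MRL config set separate learning rates for the encoder's non-memory parameters, the encoder's memory parameters, and the memory-attention network. A hedged example as a plain dict, restricted to fields visible in this diff; the values mirror the trainer defaults shown further below and are placeholders, not recommendations.

```python
# Hedged example: per-module learning rates added in this hunk.
mrl_config = {
    'lr': 3e-4,                 # base learning rate
    'critic_lr': 1e-4,
    'memory_lr': 5e-4,
    'encoder_lr': 3e-4,         # new: encoder (non-memory) parameters
    'encoder_memory_lr': 5e-4,  # new: encoder memory parameters
    'memory_attn_lr': 5e-4,     # new: memory-attention network
    'weight_decay': 0.01,
}
```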
@@ -68,6 +71,9 @@ class CurriculumConfig(TypedDict):
     memory_lr: Optional[float]
     critic_lr: Optional[float]
     critic_encoder_lr: Optional[float]
+    encoder_lr: Optional[float]
+    encoder_memory_lr: Optional[float]
+    memory_attn_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -95,7 +101,10 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
-OptimField: TypeAlias = Literal[
+OptimField: TypeAlias = Literal[
+    'lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr',
+    'memory_lr', 'encoder_lr', 'encoder_memory_lr', 'memory_attn_lr'
+]
 
 class MRLTrainer:
     def __init__(
@@ -181,6 +190,9 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
                 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
+                'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),
+                'encoder_memory_lr': config.get('encoder_memory_lr', config.get('memory_lr', 5e-4)),
+                'memory_attn_lr': config.get('memory_attn_lr', config.get('memory_lr', 5e-4)),
             }
         else:
             self.base_optim_config = {
@@ -190,6 +202,7 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
                 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
+                'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),
             }
 
         self.optim_config = self.base_optim_config
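The defaults above form a fallback chain: `encoder_lr` falls back to `lr`, while `encoder_memory_lr` and `memory_attn_lr` fall back to `memory_lr`. A minimal sketch of that resolution on a plain dict (the example config is hypothetical):

```python
# Reproduces the config.get fallback chain used in the trainer defaults above.
config = {'lr': 2e-4, 'memory_lr': 8e-4}  # hypothetical user config

resolved = {
    'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),                       # -> 2e-4 (falls back to lr)
    'encoder_memory_lr': config.get('encoder_memory_lr', config.get('memory_lr', 5e-4)),  # -> 8e-4 (falls back to memory_lr)
    'memory_attn_lr': config.get('memory_attn_lr', config.get('memory_lr', 5e-4)),        # -> 8e-4 (falls back to memory_lr)
}
assert resolved == {'encoder_lr': 2e-4, 'encoder_memory_lr': 8e-4, 'memory_attn_lr': 8e-4}
```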
@@ -574,9 +587,18 @@ class MRLTrainer:
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
-        print(f"Encoder grad norm - total: {encoder_total:.
-        print(f"Decoder grad norm - total: {decoder_total:.
-        print(f"Memory attention grad norm - total: {mem_att_total:.
+        print(f"Encoder grad norm - total: {encoder_total:.6f}, mean: {encoder_mean:.6f}")
+        print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
+        # decoder's cross att
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[0] for layer in self.actor.decoder.model.layers]
+        print(f"Decoder cross-att mean total norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+
+        mem_att_norms = [get_gradient_norms(layer)[0] for layer in self.actor.memory_attention.model.attention_layers]
+        print(f"Memory attention layers mean total norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
+
+        enc_ff_norms = [get_gradient_norms(layer.ff)[0] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder ff mean total norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
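The expanded diagnostics rely on `get_gradient_norms` returning a `(total, mean)` pair per module. The rxnn implementation is not shown in this diff; a plausible stand-in, assuming "total" is the L2 norm over all gradients and "mean" the average per-parameter norm, is sketched below.

```python
import torch
from torch import nn

def gradient_norms(module: nn.Module) -> tuple[float, float]:
    """Hypothetical stand-in for get_gradient_norms (assumption, not the rxnn source):
    total L2 norm over a module's parameter gradients and the mean per-parameter norm."""
    norms = [p.grad.norm(2).item() for p in module.parameters() if p.grad is not None]
    if not norms:
        return 0.0, 0.0
    total = sum(n ** 2 for n in norms) ** 0.5  # norm of the concatenated gradient vector
    return total, sum(norms) / len(norms)

# usage sketch
layer = nn.Linear(8, 8)
layer(torch.randn(2, 8)).sum().backward()
total, mean = gradient_norms(layer)
print(f"grad norm - total: {total:.6f}, mean: {mean:.6f}")
```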
@@ -969,14 +991,16 @@ class MRLTrainer:
             unfreeze_lr: float,
     ) -> torch.optim.Optimizer:
         memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
-
+        encoder_memory_lr = self.optim_config['encoder_memory_lr'] if 'encoder_memory_lr' in self.optim_config else self.optim_config['encoder_lr']
+        memory_attn_lr = self.optim_config['memory_attn_lr'] if 'memory_attn_lr' in self.optim_config else self.optim_config['lr']
+        model_lr, embedding_lr, encoder_lr = self.optim_config['lr'], self.optim_config['embedding_lr'], self.optim_config['encoder_lr']
 
         if mode == 'update':
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.encoder.not_memory_parameters(), 'lr':
-                {'params': self.actor.encoder.memory_parameters(), 'lr':
-                {'params': self.actor.memory_attention_parameters(), 'lr':
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
@@ -993,17 +1017,17 @@ class MRLTrainer:
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
-                {'params': self.actor.encoder.memory_parameters(), 'lr':
-                {'params': self.actor.memory_attention_parameters(), 'lr':
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
         else:
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.encoder.not_memory_parameters(), 'lr':
-                {'params': self.actor.encoder.memory_parameters(), 'lr':
-                {'params': self.actor.memory_attention_parameters(), 'lr':
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
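Each branch builds parameter groups with a different learning rate per module, which a single PyTorch optimizer then consumes. A minimal sketch of that pattern with stand-in modules follows; `AdamW` and the toy `nn.Linear` modules are illustrative only, since the diff does not show which optimizer class the trainer instantiates.

```python
import torch
from torch import nn

# Stand-ins for the encoder / memory-attention / decoder parameter groups.
encoder, memory_attn, decoder = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': 3e-4},      # encoder_lr
    {'params': memory_attn.parameters(), 'lr': 5e-4},  # memory_attn_lr
    {'params': decoder.parameters(), 'lr': 3e-4},      # model lr
], weight_decay=0.01)

# each group keeps its own learning rate
for group in optimizer.param_groups:
    print(group['lr'])
```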
@@ -1028,11 +1052,14 @@ class MRLTrainer:
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
 
-        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'encoder_lr']
+        mem_optim_params: list[OptimField] = ['memory_lr', 'encoder_memory_lr', 'memory_attn_lr']
 
         has_any_optim_param = any(
             has_param(field) for field in optim_params
-        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and any(
+            has_param(field) for field in mem_optim_params
+        ))
 
         if has_any_optim_param:
             if config.get('separate_memory_lr', False):
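With this change, setting `separate_memory_lr` alone no longer counts as an optimizer override; one of the memory-specific LR fields must also be present. A standalone reproduction over plain dicts (field lists copied from the hunk):

```python
# Reproduces the extended override check on plain dicts instead of CurriculumConfig.
OPTIM_PARAMS = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'encoder_lr']
MEM_OPTIM_PARAMS = ['memory_lr', 'encoder_memory_lr', 'memory_attn_lr']

def has_any_optim_param(config: dict) -> bool:
    def has_param(field: str) -> bool:
        return field in config and config[field] is not None
    return any(has_param(f) for f in OPTIM_PARAMS) or (
        has_param('separate_memory_lr') and config['separate_memory_lr']
        and any(has_param(f) for f in MEM_OPTIM_PARAMS)
    )

assert has_any_optim_param({'encoder_lr': 2e-4}) is True
assert has_any_optim_param({'separate_memory_lr': True}) is False  # no memory LR given
assert has_any_optim_param({'separate_memory_lr': True, 'memory_attn_lr': 5e-4}) is True
```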
@@ -1044,7 +1071,10 @@ class MRLTrainer:
                                                       self.base_optim_config['critic_weight_decay']),
                 'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
-                'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
+                'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr']),
+                'encoder_lr': config.get('encoder_lr', self.base_optim_config['encoder_lr']),
+                'encoder_memory_lr': config.get('encoder_memory_lr', self.base_optim_config['encoder_memory_lr']),
+                'memory_attn_lr': config.get('memory_attn_lr', self.base_optim_config['memory_attn_lr']),
             }
         else:
             self.optim_config = {
@@ -1054,7 +1084,8 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
                 'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
-                'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
+                'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr']),
+                'encoder_lr': config.get('encoder_lr', self.base_optim_config['encoder_lr']),
             }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config: