rxnn 0.2.68__tar.gz → 0.2.70__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.68 → rxnn-0.2.70}/PKG-INFO +1 -1
- {rxnn-0.2.68 → rxnn-0.2.70}/pyproject.toml +1 -1
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/attention.py +8 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/dataset.py +14 -5
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/models.py +2 -2
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/mrl.py +20 -7
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/layers.py +17 -2
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/models.py +21 -20
- {rxnn-0.2.68 → rxnn-0.2.70}/LICENSE +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/README.md +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/utils.py +0 -0
@@ -64,6 +64,8 @@ class StmMemoryAttention(nn.Module):
|
|
64
64
|
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
65
65
|
encoded_layer_data = x[i]
|
66
66
|
normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
|
67
|
+
if torch.isnan(normalized_layer_stm).any():
|
68
|
+
print(f"NaN detected in {i} layer memory norm output")
|
67
69
|
|
68
70
|
if self.debug_mode and self.training:
|
69
71
|
if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
|
@@ -72,7 +74,13 @@ class StmMemoryAttention(nn.Module):
|
|
72
74
|
else:
|
73
75
|
self.debug_step += 1
|
74
76
|
|
77
|
+
if torch.isnan(encoded_layer_data).any():
|
78
|
+
print(f"NaN detected in {i} layer encoded data input")
|
79
|
+
|
75
80
|
new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
|
81
|
+
if torch.isnan(new_layer_stm).any():
|
82
|
+
print(f"NaN detected in {i} layer memory attention output")
|
83
|
+
|
76
84
|
if self.use_gated_residual:
|
77
85
|
new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
|
78
86
|
else:
|
@@ -4,7 +4,7 @@ from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
|
|
4
4
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
5
5
|
from .tokenizer import load_tokenizer_from_hf_hub
|
6
6
|
|
7
|
-
from typing import Union, TypedDict, Optional, TypeAlias, Any
|
7
|
+
from typing import Union, TypedDict, Optional, TypeAlias, Any, Literal
|
8
8
|
|
9
9
|
|
10
10
|
class BaseDataset(Dataset):
|
@@ -854,8 +854,8 @@ class EncoderSftDataset(BaseInteractionDataset):
|
|
854
854
|
'labels': labels
|
855
855
|
}
|
856
856
|
|
857
|
-
|
858
|
-
MrlDataItem: TypeAlias = dict[str, Union[dict[
|
857
|
+
ItemFields: TypeAlias = Literal['input_ids', 'attention_mask']
|
858
|
+
MrlDataItem: TypeAlias = dict[str, Union[dict[ItemFields, torch.Tensor], list[dict[str, dict[ItemFields, torch.Tensor]]]]]
|
859
859
|
|
860
860
|
|
861
861
|
class MrlCurriculumDataset(Dataset):
|
@@ -1031,7 +1031,7 @@ class MrlCurriculumDataset(Dataset):
|
|
1031
1031
|
"""Collate function for MRL curriculum dataset with nested interactions"""
|
1032
1032
|
|
1033
1033
|
def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> \
|
1034
|
-
dict[str, dict[
|
1034
|
+
dict[str, dict[ItemFields, torch.Tensor]]:
|
1035
1035
|
"""Helper to collate a batch of interactions"""
|
1036
1036
|
return {
|
1037
1037
|
'query': {
|
@@ -1047,13 +1047,22 @@ class MrlCurriculumDataset(Dataset):
|
|
1047
1047
|
batch_interactions = [x['interactions'] for x in batch]
|
1048
1048
|
transposed_interactions = list(zip(*batch_interactions))
|
1049
1049
|
|
1050
|
-
|
1050
|
+
def has_nans(tensor: dict[ItemFields, torch.Tensor]) -> bool:
|
1051
|
+
return torch.isnan(tensor['input_ids']).any().item() or torch.isnan(tensor['attention_mask']).any().item()
|
1052
|
+
|
1053
|
+
results: MrlDataItem = {
|
1051
1054
|
**collate_interaction_batch(batch), # Collate initial query and answer
|
1052
1055
|
'interactions': [
|
1053
1056
|
collate_interaction_batch(step_batch) for step_batch in transposed_interactions
|
1054
1057
|
]
|
1055
1058
|
}
|
1056
1059
|
|
1060
|
+
assert not has_nans(results['query']), "NaN in query"
|
1061
|
+
assert not has_nans(results['answer']), "NaN in answer"
|
1062
|
+
assert not any([(has_nans(item['query']) or has_nans(item['answer'])) for item in results['interactions']]), "NaN in interactions"
|
1063
|
+
|
1064
|
+
return results
|
1065
|
+
|
1057
1066
|
|
1058
1067
|
class MrlDatasetItem(TypedDict):
|
1059
1068
|
steps: int
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import torch
|
2
2
|
import torch.nn as nn
|
3
3
|
from enum import Enum
|
4
|
-
from typing import Literal, Iterator
|
4
|
+
from typing import Literal, Iterator, Optional
|
5
5
|
from huggingface_hub import PyTorchModelHubMixin
|
6
6
|
from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
|
7
7
|
from ..transformers.ff import GatedLinearUnit, get_activation_layer
|
@@ -188,7 +188,7 @@ class MrlActorModel(nn.Module):
|
|
188
188
|
list(self.memory_attention_parameters())
|
189
189
|
))
|
190
190
|
|
191
|
-
def moe_router_loss(self):
|
191
|
+
def moe_router_loss(self) -> Optional[torch.Tensor]:
|
192
192
|
if self.encoder.model.use_moe and self.decoder.model.use_moe:
|
193
193
|
return (self.encoder.model.moe_router_loss() + self.decoder.model.moe_router_loss()) / 2
|
194
194
|
elif self.encoder.model.use_moe:
|
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
|
|
41
41
|
use_memory_warmup: Optional[bool]
|
42
42
|
debug_mode: Optional[bool]
|
43
43
|
debug_interval: Optional[int]
|
44
|
+
clamp_logits: Optional[bool]
|
44
45
|
|
45
46
|
|
46
47
|
class MrlStrategy(Enum):
|
@@ -152,6 +153,7 @@ class MRLTrainer:
|
|
152
153
|
self.use_memory_warmup = config.get('use_memory_warmup', False)
|
153
154
|
self.debug_mode = config.get('debug_mode', False)
|
154
155
|
self.debug_interval = config.get('debug_interval', 10)
|
156
|
+
self.clamp_logits = config.get('clamp_logits', False)
|
155
157
|
# Internal update epochs config
|
156
158
|
self.shared_update_epochs = config.get('update_epochs', 10)
|
157
159
|
self.update_epochs = self.shared_update_epochs
|
@@ -589,12 +591,16 @@ class MRLTrainer:
|
|
589
591
|
actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
|
590
592
|
|
591
593
|
router_loss = actor.moe_router_loss()
|
594
|
+
if torch.isnan(router_loss).any():
|
595
|
+
print("NaN detected in router loss")
|
592
596
|
if router_loss is not None:
|
593
597
|
return main_loss + self.moe_aux_loss_scale * router_loss
|
594
598
|
else:
|
595
599
|
return main_loss
|
596
600
|
|
597
|
-
def _log_gradients(self):
|
601
|
+
def _log_gradients(self, logits: torch.Tensor):
|
602
|
+
print(
|
603
|
+
f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
|
598
604
|
encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
|
599
605
|
decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
|
600
606
|
mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
|
@@ -611,8 +617,11 @@ class MRLTrainer:
|
|
611
617
|
enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
|
612
618
|
print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
|
613
619
|
|
614
|
-
|
615
|
-
print(f"Encoder
|
620
|
+
enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
|
621
|
+
print(f"Encoder self-att mean norm: {(sum(enc_self_att_norms) / len(enc_self_att_norms)):.6f}, all: {enc_self_att_norms}")
|
622
|
+
|
623
|
+
enc_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
|
624
|
+
print(f"Encoder cross-att mean norm: {(sum(enc_att_norms) / len(enc_att_norms)):.6f}, all: {enc_att_norms}")
|
616
625
|
|
617
626
|
def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
|
618
627
|
advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
|
@@ -633,6 +642,8 @@ class MRLTrainer:
|
|
633
642
|
pad_token_id=self.pad_token_id)
|
634
643
|
logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
|
635
644
|
action=MrlActorAction.DECODE)
|
645
|
+
if self.clamp_logits:
|
646
|
+
logits = logits.clamp(min=-20.0, max=20.0)
|
636
647
|
# 4.2 Calculate policy loss with selected algorithm
|
637
648
|
policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
|
638
649
|
advantages)
|
@@ -643,9 +654,9 @@ class MRLTrainer:
|
|
643
654
|
# 4.4 Unscale and clip gradient norms
|
644
655
|
self.scaler.unscale_(self.optimizer)
|
645
656
|
torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
|
646
|
-
error_if_nonfinite=
|
657
|
+
error_if_nonfinite=self.debug_mode)
|
647
658
|
if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
|
648
|
-
self._log_gradients()
|
659
|
+
self._log_gradients(logits)
|
649
660
|
# 4.5 Run scaled optimization step
|
650
661
|
self.scaler.step(self.optimizer)
|
651
662
|
self.scaler.update()
|
@@ -655,6 +666,8 @@ class MRLTrainer:
|
|
655
666
|
pad_token_id=self.pad_token_id)
|
656
667
|
logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
|
657
668
|
action=MrlActorAction.DECODE)
|
669
|
+
if self.clamp_logits:
|
670
|
+
logits = logits.clamp(min=-20.0, max=20.0)
|
658
671
|
# 4.2 Calculate policy loss with selected algorithm
|
659
672
|
policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
|
660
673
|
policy_loss = self._moe_aux_loss(policy_loss)
|
@@ -662,9 +675,9 @@ class MRLTrainer:
|
|
662
675
|
policy_loss.backward(retain_graph=True)
|
663
676
|
# 4.4 Clip gradient norms
|
664
677
|
torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
|
665
|
-
error_if_nonfinite=
|
678
|
+
error_if_nonfinite=self.debug_mode)
|
666
679
|
if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
|
667
|
-
self._log_gradients()
|
680
|
+
self._log_gradients(logits)
|
668
681
|
# 4.5 Run scaled optimization step
|
669
682
|
self.optimizer.step()
|
670
683
|
# 5. Get float loss value for callbacks/writer
|
@@ -102,7 +102,11 @@ class ReactiveTransformerLayer(nn.Module):
|
|
102
102
|
residual = x
|
103
103
|
if not self.use_post_norm:
|
104
104
|
x = self.norm1(x)
|
105
|
+
if torch.isnan(x).any():
|
106
|
+
print("NaN detected in pre-norm (self-attention) output")
|
105
107
|
x = self.attention(x, x, x, mask=mask)
|
108
|
+
if torch.isnan(x).any():
|
109
|
+
print("NaN detected in self-attention output")
|
106
110
|
x = residual + x
|
107
111
|
if self.use_post_norm:
|
108
112
|
x = self.norm1(x)
|
@@ -110,11 +114,18 @@ class ReactiveTransformerLayer(nn.Module):
|
|
110
114
|
residual = x
|
111
115
|
if not self.use_post_norm:
|
112
116
|
x = self.norm2(x)
|
117
|
+
if torch.isnan(x).any():
|
118
|
+
print("NaN detected in pre-norm (cross-attention) output")
|
113
119
|
|
114
|
-
|
115
|
-
|
120
|
+
mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
|
121
|
+
if mask is not None else None
|
122
|
+
|
123
|
+
if torch.isnan(stm).any():
|
124
|
+
print("NaN detected in STM cross-attention input")
|
116
125
|
|
117
126
|
x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
|
127
|
+
if torch.isnan(x).any():
|
128
|
+
print("NaN detected in cross-attention output")
|
118
129
|
x = residual + x
|
119
130
|
if self.use_post_norm:
|
120
131
|
x = self.norm2(x)
|
@@ -123,7 +134,11 @@ class ReactiveTransformerLayer(nn.Module):
|
|
123
134
|
residual = x
|
124
135
|
if not self.use_post_norm:
|
125
136
|
x = self.norm3(x)
|
137
|
+
if torch.isnan(x).any():
|
138
|
+
print("NaN detected in pre-norm (ff) output")
|
126
139
|
x = self.ff(x)
|
140
|
+
if torch.isnan(x).any():
|
141
|
+
print("NaN detected in ff output")
|
127
142
|
x = residual + x
|
128
143
|
if self.use_post_norm:
|
129
144
|
x = self.norm3(x)
|
@@ -58,6 +58,15 @@ class ReactiveTransformerBase(nn.Module):
|
|
58
58
|
else:
|
59
59
|
return None
|
60
60
|
|
61
|
+
def _handle_layer(self, i: int, x: torch.Tensor, mask: torch.Tensor = None, is_shared: bool = False):
|
62
|
+
stm_layer_idx = i if is_shared else i + self.num_shared_layers
|
63
|
+
layer_stm = self.stm(stm_layer_idx)
|
64
|
+
# expand layer STM to batch size, if it's not in batch mode
|
65
|
+
if layer_stm.size(0) == 1:
|
66
|
+
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
67
|
+
layer = self.shared_layers[i] if is_shared else self.layers[i]
|
68
|
+
return layer(x, layer_stm, mask=mask)
|
69
|
+
|
61
70
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
62
71
|
# Shared logic for encoders and decoders - apply embeddings and positional encoding
|
63
72
|
x = self.embedding(x)
|
@@ -84,6 +93,8 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
|
|
84
93
|
|
85
94
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
|
86
95
|
x = super().forward(x) # apply embeddings
|
96
|
+
if torch.isnan(x).any():
|
97
|
+
print("NaN detected in decoder embedding output")
|
87
98
|
seq_len = x.size(1)
|
88
99
|
if not self.use_flash_attention and self.use_relative_embedding:
|
89
100
|
mask = create_causal_mask(seq_len, device=x.device)
|
@@ -96,18 +107,12 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
|
|
96
107
|
# Process shared layers
|
97
108
|
if self.shared_layers is not None:
|
98
109
|
for i in range(self.num_shared_layers):
|
99
|
-
|
100
|
-
# expand layer STM to batch size, if it's not in batch mode
|
101
|
-
if layer_stm.size(0) == 1:
|
102
|
-
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
103
|
-
x = self.shared_layers[i](x, layer_stm, mask=mask)
|
110
|
+
x = self._handle_layer(i, x, mask=mask, is_shared=True)
|
104
111
|
# Process own layers
|
105
112
|
for i in range(self.num_own_layers):
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
110
|
-
x = self.layers[i](x, layer_stm, mask=mask)
|
113
|
+
x = self._handle_layer(i, x, mask=mask)
|
114
|
+
if torch.isnan(x).any():
|
115
|
+
print(f"NaN detected in {i}. decoder layer output")
|
111
116
|
return self.head(self.head_norm(x) if self.use_head_norm else x)
|
112
117
|
|
113
118
|
|
@@ -116,6 +121,8 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
|
|
116
121
|
|
117
122
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
118
123
|
x = super().forward(x) # apply embeddings
|
124
|
+
if torch.isnan(x).any():
|
125
|
+
print("NaN detected in encoder embedding output")
|
119
126
|
if attention_mask is not None:
|
120
127
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
|
121
128
|
|
@@ -123,19 +130,13 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
|
|
123
130
|
# Process shared layers
|
124
131
|
if self.shared_layers is not None:
|
125
132
|
for i in range(self.num_shared_layers):
|
126
|
-
|
127
|
-
# expand layer STM to batch size, if it's not in batch mode
|
128
|
-
if layer_stm.size(0) == 1:
|
129
|
-
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
130
|
-
x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
|
133
|
+
x = self._handle_layer(i, x, mask=attention_mask, is_shared=True)
|
131
134
|
hidden_states.append(x)
|
132
135
|
# Process own layers
|
133
136
|
for i in range(self.num_own_layers):
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
138
|
-
x = self.layers[i](x, layer_stm, mask=attention_mask)
|
137
|
+
x = self._handle_layer(i, x, mask=attention_mask)
|
138
|
+
if torch.isnan(x).any():
|
139
|
+
print(f"NaN detected in {i}. encoder layer output")
|
139
140
|
hidden_states.append(x)
|
140
141
|
return x, torch.stack(hidden_states)
|
141
142
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|