rxnn 0.2.68__tar.gz → 0.2.70__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.68 → rxnn-0.2.70}/PKG-INFO +1 -1
  2. {rxnn-0.2.68 → rxnn-0.2.70}/pyproject.toml +1 -1
  3. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/attention.py +8 -0
  4. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/dataset.py +14 -5
  5. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/models.py +2 -2
  6. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/mrl.py +20 -7
  7. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/layers.py +17 -2
  8. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/models.py +21 -20
  9. {rxnn-0.2.68 → rxnn-0.2.70}/LICENSE +0 -0
  10. {rxnn-0.2.68 → rxnn-0.2.70}/README.md +0 -0
  11. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/.DS_Store +0 -0
  12. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/__init__.py +0 -0
  13. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/__init__.py +0 -0
  14. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/attention.py +0 -0
  15. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/models.py +0 -0
  16. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/experimental/moe.py +0 -0
  17. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/__init__.py +0 -0
  18. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/rxt/models.py +0 -0
  22. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/base.py +0 -0
  24. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/bml.py +0 -0
  25. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/callbacks.py +0 -0
  26. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/ddp.py +0 -0
  27. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/reward.py +0 -0
  28. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/rl.py +0 -0
  29. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/scheduler.py +0 -0
  30. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/tokenizer.py +0 -0
  31. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/training/utils.py +0 -0
  32. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/__init__.py +0 -0
  33. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/attention.py +0 -0
  34. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/ff.py +0 -0
  35. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.68 → rxnn-0.2.70}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.68
3
+ Version: 0.2.70
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.2.68"
7
+ version = "0.2.70"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -64,6 +64,8 @@ class StmMemoryAttention(nn.Module):
64
64
  layer_stm = layer_stm.expand(x.size(0), -1, -1)
65
65
  encoded_layer_data = x[i]
66
66
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
67
+ if torch.isnan(normalized_layer_stm).any():
68
+ print(f"NaN detected in {i} layer memory norm output")
67
69
 
68
70
  if self.debug_mode and self.training:
69
71
  if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
@@ -72,7 +74,13 @@ class StmMemoryAttention(nn.Module):
72
74
  else:
73
75
  self.debug_step += 1
74
76
 
77
+ if torch.isnan(encoded_layer_data).any():
78
+ print(f"NaN detected in {i} layer encoded data input")
79
+
75
80
  new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
81
+ if torch.isnan(new_layer_stm).any():
82
+ print(f"NaN detected in {i} layer memory attention output")
83
+
76
84
  if self.use_gated_residual:
77
85
  new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
78
86
  else:
@@ -4,7 +4,7 @@ from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
4
4
  from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
5
5
  from .tokenizer import load_tokenizer_from_hf_hub
6
6
 
7
- from typing import Union, TypedDict, Optional, TypeAlias, Any
7
+ from typing import Union, TypedDict, Optional, TypeAlias, Any, Literal
8
8
 
9
9
 
10
10
  class BaseDataset(Dataset):
@@ -854,8 +854,8 @@ class EncoderSftDataset(BaseInteractionDataset):
854
854
  'labels': labels
855
855
  }
856
856
 
857
-
858
- MrlDataItem: TypeAlias = dict[str, Union[dict[str, torch.Tensor], list[dict[str, dict[str, torch.Tensor]]]]]
857
+ ItemFields: TypeAlias = Literal['input_ids', 'attention_mask']
858
+ MrlDataItem: TypeAlias = dict[str, Union[dict[ItemFields, torch.Tensor], list[dict[str, dict[ItemFields, torch.Tensor]]]]]
859
859
 
860
860
 
861
861
  class MrlCurriculumDataset(Dataset):
@@ -1031,7 +1031,7 @@ class MrlCurriculumDataset(Dataset):
1031
1031
  """Collate function for MRL curriculum dataset with nested interactions"""
1032
1032
 
1033
1033
  def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> \
1034
- dict[str, dict[str, torch.Tensor]]:
1034
+ dict[str, dict[ItemFields, torch.Tensor]]:
1035
1035
  """Helper to collate a batch of interactions"""
1036
1036
  return {
1037
1037
  'query': {
@@ -1047,13 +1047,22 @@ class MrlCurriculumDataset(Dataset):
1047
1047
  batch_interactions = [x['interactions'] for x in batch]
1048
1048
  transposed_interactions = list(zip(*batch_interactions))
1049
1049
 
1050
- return {
1050
+ def has_nans(tensor: dict[ItemFields, torch.Tensor]) -> bool:
1051
+ return torch.isnan(tensor['input_ids']).any().item() or torch.isnan(tensor['attention_mask']).any().item()
1052
+
1053
+ results: MrlDataItem = {
1051
1054
  **collate_interaction_batch(batch), # Collate initial query and answer
1052
1055
  'interactions': [
1053
1056
  collate_interaction_batch(step_batch) for step_batch in transposed_interactions
1054
1057
  ]
1055
1058
  }
1056
1059
 
1060
+ assert not has_nans(results['query']), "NaN in query"
1061
+ assert not has_nans(results['answer']), "NaN in answer"
1062
+ assert not any([(has_nans(item['query']) or has_nans(item['answer'])) for item in results['interactions']]), "NaN in interactions"
1063
+
1064
+ return results
1065
+
1057
1066
 
1058
1067
  class MrlDatasetItem(TypedDict):
1059
1068
  steps: int
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  from enum import Enum
4
- from typing import Literal, Iterator
4
+ from typing import Literal, Iterator, Optional
5
5
  from huggingface_hub import PyTorchModelHubMixin
6
6
  from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
7
7
  from ..transformers.ff import GatedLinearUnit, get_activation_layer
@@ -188,7 +188,7 @@ class MrlActorModel(nn.Module):
188
188
  list(self.memory_attention_parameters())
189
189
  ))
190
190
 
191
- def moe_router_loss(self):
191
+ def moe_router_loss(self) -> Optional[torch.Tensor]:
192
192
  if self.encoder.model.use_moe and self.decoder.model.use_moe:
193
193
  return (self.encoder.model.moe_router_loss() + self.decoder.model.moe_router_loss()) / 2
194
194
  elif self.encoder.model.use_moe:
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
41
41
  use_memory_warmup: Optional[bool]
42
42
  debug_mode: Optional[bool]
43
43
  debug_interval: Optional[int]
44
+ clamp_logits: Optional[bool]
44
45
 
45
46
 
46
47
  class MrlStrategy(Enum):
@@ -152,6 +153,7 @@ class MRLTrainer:
152
153
  self.use_memory_warmup = config.get('use_memory_warmup', False)
153
154
  self.debug_mode = config.get('debug_mode', False)
154
155
  self.debug_interval = config.get('debug_interval', 10)
156
+ self.clamp_logits = config.get('clamp_logits', False)
155
157
  # Internal update epochs config
156
158
  self.shared_update_epochs = config.get('update_epochs', 10)
157
159
  self.update_epochs = self.shared_update_epochs
@@ -589,12 +591,16 @@ class MRLTrainer:
589
591
  actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
590
592
 
591
593
  router_loss = actor.moe_router_loss()
594
+ if torch.isnan(router_loss).any():
595
+ print("NaN detected in router loss")
592
596
  if router_loss is not None:
593
597
  return main_loss + self.moe_aux_loss_scale * router_loss
594
598
  else:
595
599
  return main_loss
596
600
 
597
- def _log_gradients(self):
601
+ def _log_gradients(self, logits: torch.Tensor):
602
+ print(
603
+ f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
598
604
  encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
599
605
  decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
600
606
  mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
@@ -611,8 +617,11 @@ class MRLTrainer:
611
617
  enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
612
618
  print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
613
619
 
614
- enc_ff_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
615
- print(f"Encoder cross-att mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
620
+ enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
621
+ print(f"Encoder self-att mean norm: {(sum(enc_self_att_norms) / len(enc_self_att_norms)):.6f}, all: {enc_self_att_norms}")
622
+
623
+ enc_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
624
+ print(f"Encoder cross-att mean norm: {(sum(enc_att_norms) / len(enc_att_norms)):.6f}, all: {enc_att_norms}")
616
625
 
617
626
  def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
618
627
  advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
@@ -633,6 +642,8 @@ class MRLTrainer:
633
642
  pad_token_id=self.pad_token_id)
634
643
  logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
635
644
  action=MrlActorAction.DECODE)
645
+ if self.clamp_logits:
646
+ logits = logits.clamp(min=-20.0, max=20.0)
636
647
  # 4.2 Calculate policy loss with selected algorithm
637
648
  policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
638
649
  advantages)
@@ -643,9 +654,9 @@ class MRLTrainer:
643
654
  # 4.4 Unscale and clip gradient norms
644
655
  self.scaler.unscale_(self.optimizer)
645
656
  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
646
- error_if_nonfinite=False)
657
+ error_if_nonfinite=self.debug_mode)
647
658
  if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
648
- self._log_gradients()
659
+ self._log_gradients(logits)
649
660
  # 4.5 Run scaled optimization step
650
661
  self.scaler.step(self.optimizer)
651
662
  self.scaler.update()
@@ -655,6 +666,8 @@ class MRLTrainer:
655
666
  pad_token_id=self.pad_token_id)
656
667
  logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
657
668
  action=MrlActorAction.DECODE)
669
+ if self.clamp_logits:
670
+ logits = logits.clamp(min=-20.0, max=20.0)
658
671
  # 4.2 Calculate policy loss with selected algorithm
659
672
  policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
660
673
  policy_loss = self._moe_aux_loss(policy_loss)
@@ -662,9 +675,9 @@ class MRLTrainer:
662
675
  policy_loss.backward(retain_graph=True)
663
676
  # 4.4 Clip gradient norms
664
677
  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
665
- error_if_nonfinite=False)
678
+ error_if_nonfinite=self.debug_mode)
666
679
  if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
667
- self._log_gradients()
680
+ self._log_gradients(logits)
668
681
  # 4.5 Run scaled optimization step
669
682
  self.optimizer.step()
670
683
  # 5. Get float loss value for callbacks/writer
@@ -102,7 +102,11 @@ class ReactiveTransformerLayer(nn.Module):
102
102
  residual = x
103
103
  if not self.use_post_norm:
104
104
  x = self.norm1(x)
105
+ if torch.isnan(x).any():
106
+ print("NaN detected in pre-norm (self-attention) output")
105
107
  x = self.attention(x, x, x, mask=mask)
108
+ if torch.isnan(x).any():
109
+ print("NaN detected in self-attention output")
106
110
  x = residual + x
107
111
  if self.use_post_norm:
108
112
  x = self.norm1(x)
@@ -110,11 +114,18 @@ class ReactiveTransformerLayer(nn.Module):
110
114
  residual = x
111
115
  if not self.use_post_norm:
112
116
  x = self.norm2(x)
117
+ if torch.isnan(x).any():
118
+ print("NaN detected in pre-norm (cross-attention) output")
113
119
 
114
- if mask is not None:
115
- mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
120
+ mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
121
+ if mask is not None else None
122
+
123
+ if torch.isnan(stm).any():
124
+ print("NaN detected in STM cross-attention input")
116
125
 
117
126
  x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
127
+ if torch.isnan(x).any():
128
+ print("NaN detected in cross-attention output")
118
129
  x = residual + x
119
130
  if self.use_post_norm:
120
131
  x = self.norm2(x)
@@ -123,7 +134,11 @@ class ReactiveTransformerLayer(nn.Module):
123
134
  residual = x
124
135
  if not self.use_post_norm:
125
136
  x = self.norm3(x)
137
+ if torch.isnan(x).any():
138
+ print("NaN detected in pre-norm (ff) output")
126
139
  x = self.ff(x)
140
+ if torch.isnan(x).any():
141
+ print("NaN detected in ff output")
127
142
  x = residual + x
128
143
  if self.use_post_norm:
129
144
  x = self.norm3(x)
@@ -58,6 +58,15 @@ class ReactiveTransformerBase(nn.Module):
58
58
  else:
59
59
  return None
60
60
 
61
+ def _handle_layer(self, i: int, x: torch.Tensor, mask: torch.Tensor = None, is_shared: bool = False):
62
+ stm_layer_idx = i if is_shared else i + self.num_shared_layers
63
+ layer_stm = self.stm(stm_layer_idx)
64
+ # expand layer STM to batch size, if it's not in batch mode
65
+ if layer_stm.size(0) == 1:
66
+ layer_stm = layer_stm.expand(x.size(0), -1, -1)
67
+ layer = self.shared_layers[i] if is_shared else self.layers[i]
68
+ return layer(x, layer_stm, mask=mask)
69
+
61
70
  def forward(self, x: torch.Tensor) -> torch.Tensor:
62
71
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
63
72
  x = self.embedding(x)
@@ -84,6 +93,8 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
84
93
 
85
94
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
86
95
  x = super().forward(x) # apply embeddings
96
+ if torch.isnan(x).any():
97
+ print("NaN detected in decoder embedding output")
87
98
  seq_len = x.size(1)
88
99
  if not self.use_flash_attention and self.use_relative_embedding:
89
100
  mask = create_causal_mask(seq_len, device=x.device)
@@ -96,18 +107,12 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
96
107
  # Process shared layers
97
108
  if self.shared_layers is not None:
98
109
  for i in range(self.num_shared_layers):
99
- layer_stm = self.stm(i)
100
- # expand layer STM to batch size, if it's not in batch mode
101
- if layer_stm.size(0) == 1:
102
- layer_stm = layer_stm.expand(x.size(0), -1, -1)
103
- x = self.shared_layers[i](x, layer_stm, mask=mask)
110
+ x = self._handle_layer(i, x, mask=mask, is_shared=True)
104
111
  # Process own layers
105
112
  for i in range(self.num_own_layers):
106
- layer_stm = self.stm(i)
107
- # expand layer STM to batch size, if it's not in batch mode
108
- if layer_stm.size(0) == 1:
109
- layer_stm = layer_stm.expand(x.size(0), -1, -1)
110
- x = self.layers[i](x, layer_stm, mask=mask)
113
+ x = self._handle_layer(i, x, mask=mask)
114
+ if torch.isnan(x).any():
115
+ print(f"NaN detected in {i}. decoder layer output")
111
116
  return self.head(self.head_norm(x) if self.use_head_norm else x)
112
117
 
113
118
 
@@ -116,6 +121,8 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
116
121
 
117
122
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
118
123
  x = super().forward(x) # apply embeddings
124
+ if torch.isnan(x).any():
125
+ print("NaN detected in encoder embedding output")
119
126
  if attention_mask is not None:
120
127
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
121
128
 
@@ -123,19 +130,13 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
123
130
  # Process shared layers
124
131
  if self.shared_layers is not None:
125
132
  for i in range(self.num_shared_layers):
126
- layer_stm = self.stm(i)
127
- # expand layer STM to batch size, if it's not in batch mode
128
- if layer_stm.size(0) == 1:
129
- layer_stm = layer_stm.expand(x.size(0), -1, -1)
130
- x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
133
+ x = self._handle_layer(i, x, mask=attention_mask, is_shared=True)
131
134
  hidden_states.append(x)
132
135
  # Process own layers
133
136
  for i in range(self.num_own_layers):
134
- layer_stm = self.stm(i)
135
- # expand layer STM to batch size, if it's not in batch mode
136
- if layer_stm.size(0) == 1:
137
- layer_stm = layer_stm.expand(x.size(0), -1, -1)
138
- x = self.layers[i](x, layer_stm, mask=attention_mask)
137
+ x = self._handle_layer(i, x, mask=attention_mask)
138
+ if torch.isnan(x).any():
139
+ print(f"NaN detected in {i}. encoder layer output")
139
140
  hidden_states.append(x)
140
141
  return x, torch.stack(hidden_states)
141
142
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes