rxnn 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/norm.py CHANGED
@@ -7,10 +7,11 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self,
         num_slots: int,
         dim: int,
-        decay: float = 0.99,
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -4.0
+        init_gate: float = -2.0,
+        per_dim_scale: bool = False,
     ):
         super(AdaptivePositionalMemoryNorm, self).__init__()
         self.use_gate = use_gate
@@ -20,7 +21,8 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6
 
         # Learnable parameters
-        self.scale = nn.Parameter(torch.ones(num_slots, dim)) if use_scale else None
+        scale_shape = (num_slots, 1) if not per_dim_scale else (dim,)
+        self.scale = nn.Parameter(torch.ones(*scale_shape)) if use_scale else None
         self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None
 
         # EMA buffers
@@ -28,7 +30,7 @@ class AdaptivePositionalMemoryNorm(nn.Module):
 
         # Initialize parameters
         if self.scale is not None:
-            nn.init.normal_(self.scale, mean=1.0, std=0.01)
+            nn.init.normal_(self.scale, mean=1.0, std=0.1)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Calculate current RMS per slot
@@ -45,7 +47,7 @@ class AdaptivePositionalMemoryNorm(nn.Module):
 
         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale  # [batch_size, num_slots, dim] * [num_slots, dim]
+            x_norm = x_norm * self.scale  # [batch_size, num_slots, dim] * [num_slots, 1] or [dim]
 
         # Apply gating mechanism
         if self.use_gate:
@@ -148,24 +150,26 @@ class MemoryNormConfig(TypedDict):
     use_gate: bool
     init_gate: float
     init_scale: float
+    per_dim_scale: bool
 
 def init_memory_norm(
     norm_type: str,
     dim: int,
     num_slots: int = None,
-    decay: float = 0.99,
+    decay: float = 0.9,
     use_scale: bool = True,
     use_gate: bool = True,
-    init_gate: float = -4.0,
+    init_gate: float = -2.0,
     init_scale: float = 1.0,
+    per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ["layer", "rms", "adaptive", "positional"]
-    if norm_type == "layer":
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type == "rms":
+    elif norm_type == 'rms':
         return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type == "adaptive":
+    elif norm_type == 'adaptive':
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
-    elif norm_type == "positional":
-        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
+    elif norm_type == 'positional':
+        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
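
For reference, a minimal sketch of what the new per_dim_scale flag changes (assuming rxnn 0.2.27 is installed; the slot and dim sizes here are arbitrary):

    import torch
    from rxnn.memory.norm import init_memory_norm

    # Default: one learned scale per memory slot, broadcast over the feature dim
    norm = init_memory_norm('positional', dim=256, num_slots=128)
    print(norm.scale.shape)  # torch.Size([128, 1])

    # per_dim_scale=True: one learned scale per feature dimension instead
    norm = init_memory_norm('positional', dim=256, num_slots=128, per_dim_scale=True)
    print(norm.scale.shape)  # torch.Size([256])

    x = torch.randn(2, 128, 256)  # [batch_size, num_slots, dim]
    print(norm(x).shape)          # torch.Size([2, 128, 256])

Either shape broadcasts cleanly against the [batch_size, num_slots, dim] input, which is what the updated comment in forward() points out.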
rxnn/rxt/models.py CHANGED
@@ -13,6 +13,7 @@ from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -76,8 +77,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
-        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                  'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -92,20 +95,25 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=is_causal, num_experts=att_experts,
+                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups)
 
         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
         else:
-            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type, cross_att_groups or att_groups, rope=rope,
-                                                                 use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                                 max_seq_len=seq_len, is_causal=is_causal, num_experts=att_experts,
-                                                                 num_query_experts=att_query_experts,
-                                                                 num_query_groups=cross_att_query_groups or att_query_groups, rope_only_for_query=True)
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
+                                                                 cross_att_groups or att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention,
+                                                                 dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=is_causal,
+                                                                 num_experts=att_experts,
+                                                                 num_query_experts=att_query_experts,
+                                                                 num_query_groups=cross_att_query_groups or att_query_groups,
+                                                                 rope_only_for_query=True)
 
         layers = nn.ModuleList([
             ReactiveTransformerLayer(
@@ -137,6 +145,12 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.memory_parameters()
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.not_memory_parameters()
+
     def freeze_without_memory(self, unfreeze_norms: bool = True):
         for param in self.model.parameters():
             param.requires_grad_(False)
@@ -211,20 +225,9 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
         return self.model(x, attention_mask=attention_mask)
 
 
-def build_rxt_alpha_for_pretraining(
-        encoder_config: RxTAlphaComponentConfig,
-        decoder_config: RxTAlphaComponentConfig,
-) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
-    encoder = RxTAlphaEncoder(**encoder_config)
-    decoder = RxTAlphaDecoder(**decoder_config)
-
-    encoder.load_shared_memory(decoder.model.stm)
-    encoder.load_shared_embedding(decoder.model.embedding)
-
-    return encoder, decoder
-
 class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) memory attention model"""
+
     def __init__(
         self,
         num_layers: int = 12,
@@ -234,17 +237,21 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         stm_size: int = 1024,
         use_flash_attention: bool = False,
         att_dropout: float = 0.0,
-        norm_type: str = 'rms',
         att_groups: int = 1,
         att_type: str = 'sqa',
         att_experts: int = None,
         att_query_experts: int = None,
         att_query_groups: int = None,
+        norm_type: str = 'rms',
+        norm_init_gate: float = -2.0,
+        norm_per_dim_scale: bool = False,
+        norm_decay: float = 0.9,
         **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
 
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
@@ -256,11 +263,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                                                           max_seq_len=seq_len, is_causal=False,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups, rope_only_for_keys=True)
 
-        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
 
@@ -283,4 +293,3 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
-
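
The new norm_* keywords surface the init_memory_norm changes on the memory-attention model itself. A construction sketch, assuming rxnn 0.2.27; embed_dim, att_heads and seq_len are set explicitly because their defaults fall outside this hunk, and the values chosen are placeholders:

    from rxnn.rxt.models import RxTAlphaMemoryAttention

    mem_att = RxTAlphaMemoryAttention(
        num_layers=12,
        embed_dim=512,    # assumed value
        att_heads=16,     # assumed value
        seq_len=1024,     # assumed value
        stm_size=1024,
        norm_type='positional',  # forwarded to init_memory_norm for every layer
        norm_init_gate=-2.0,
        norm_per_dim_scale=False,
        norm_decay=0.9,
    )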
rxnn/training/models.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
-from typing import Literal
+from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -75,7 +75,7 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
+    def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory(unfreeze_norms=True)
@@ -124,6 +124,29 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters() +
+            list(self.memory_attention.parameters())
+        ))
+
+    def memory_cross_attention_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters()
+        ))
+
+    def memory_attention_parameters(self) -> Iterator[nn.Parameter]:
+        return self.memory_attention.parameters()
+
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.not_memory_parameters() +
+            self.decoder.not_memory_parameters()
+        ))
+
     def unique_parameters(self):
         return list(set(
             list(self.encoder.parameters()) +
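
The list(set(...)) pattern works because tensors hash by object identity, so parameters shared between encoder and decoder are counted once and never handed to an optimizer twice. A toy illustration with plain PyTorch modules (not the actual rxnn classes):

    import torch.nn as nn

    shared = nn.Linear(4, 4)  # submodule reused by two parents
    a = nn.Sequential(shared)
    b = nn.Sequential(shared)

    combined = list(a.parameters()) + list(b.parameters())
    unique = list(set(combined))       # dedupes by object identity
    print(len(combined), len(unique))  # 4 2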
rxnn/training/mrl.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, TypeAlias, Literal
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -17,6 +17,8 @@ from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 
 class MrlConfig(TypedDict):
     lr: float
+    separate_memory_lr: Optional[bool]
+    memory_lr: Optional[float]
     critic_lr: float
     max_seq_len: int
     critic_max_len: int
@@ -29,6 +31,8 @@ class MrlStrategy(Enum):
     MULTI_STEP_STRATEGY = 2
     LONG_RANGE_STRATEGY = 3
 
+UnfreezeItem = Union[int, tuple[int, float]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
 
 class CurriculumConfig(TypedDict):
     steps: int
@@ -37,12 +41,14 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[Union[int, tuple[int, int, int, int]]]
+    unfreeze_epoch: Optional[UnfreezeEpochsStrategy]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    separate_memory_lr: Optional[bool]
     lr: Optional[float]
+    memory_lr: Optional[float]
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
@@ -84,6 +90,7 @@ class MRLTrainer:
         use_amp: bool = False,
         dtype: torch.dtype = torch.float32,
         callbacks: list[MrlTrainerCallback] = None,
+
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
@@ -123,15 +130,27 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype
 
-        self.base_optim_config = {
-            'lr': config.get('lr', 3e-4),
-            'critic_lr': config.get('critic_lr', 1e-4),
-            'weight_decay': config.get('weight_decay', 0.01),
-            'critic_weight_decay': config.get('critic_weight_decay', 0.01),
-        }
+        self.separate_memory_lr = config.get('separate_memory_lr', False)
+
+        if self.separate_memory_lr:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'memory_lr': config.get('memory_lr', 5e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+        else:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+
+        self.optim_config = self.base_optim_config
 
-        # Optimizers
-        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
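
For example, a trainer config enabling the split (field names follow MrlConfig above; the numeric values are placeholders and the remaining MrlConfig keys are omitted):

    config = {
        'separate_memory_lr': True,
        'lr': 3e-4,          # non-memory parameters
        'memory_lr': 5e-4,   # memory cross-attention + memory attention
        'critic_lr': 1e-4,
        'weight_decay': 0.01,
        'critic_weight_decay': 0.01,
        'max_seq_len': 1024,
        'critic_max_len': 2048,
    }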
@@ -158,18 +177,34 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
-    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
-        optimizer = torch.optim.AdamW(
-            self.actor.unique_parameters(),
-            lr=lr,
-            weight_decay=weight_decay,
-        )
+    def _init_optimizers(
+            self,
+            lr: float,
+            critic_lr: float,
+            weight_decay: float,
+            critic_weight_decay: float,
+            memory_lr: Optional[float] = None,
+    ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        if memory_lr is not None:
+            optimizer = torch.optim.AdamW([
+                {'params': self.actor.not_memory_parameters(), 'lr': lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ],
+                weight_decay=weight_decay,
+            )
+        else:
+            optimizer = torch.optim.AdamW(
+                self.actor.unique_parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+            )
 
         critic_optimizer = torch.optim.AdamW(
             self.critic.parameters(),
             lr=critic_lr,
             weight_decay=critic_weight_decay,
         )
+
         return optimizer, critic_optimizer
 
 
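The resulting actor optimizer is a single AdamW with two param groups at different learning rates; the same pattern in isolation, with stand-in modules for the two parameter sets:

    import torch
    import torch.nn as nn

    memory_part = nn.Linear(8, 8)  # stands in for actor.memory_parameters()
    rest = nn.Linear(8, 8)         # stands in for actor.not_memory_parameters()

    optimizer = torch.optim.AdamW([
        {'params': rest.parameters(), 'lr': 3e-4},
        {'params': memory_part.parameters(), 'lr': 5e-4},
    ], weight_decay=0.01)

    print([group['lr'] for group in optimizer.param_groups])  # [0.0003, 0.0005]
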
@@ -712,7 +747,7 @@ class MRLTrainer:
 
         return should_stop_stage
 
-    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, int], tuple[bool, int, float]]:
+    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
         # 1. Set common fields based on config
         self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
         self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
@@ -722,13 +757,29 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
-            self.optimizer, self.critic_optimizer = self._init_optimizers(
-                lr=config.get('lr', self.base_optim_config['lr']),
-                critic_lr=config.get('critic_lr', self.base_optim_config['critic_lr']),
-                weight_decay=config.get('weight_decay', self.base_optim_config['weight_decay']),
-                critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay'])
-            )
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+            if config.get('separate_memory_lr', False):
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                }
+            else:
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                }
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+        elif self.optim_config != self.base_optim_config:
+            self.optim_config = self.base_optim_config
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+
+
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
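
A stage can now override the memory LR without touching the base config; note that the guard above indexes lr, critic_lr, weight_decay, critic_weight_decay, separate_memory_lr and memory_lr with [], so a stage dict should carry all six keys. A sketch (dataset, callback and other CurriculumConfig fields omitted; values are placeholders):

    stage_config = {
        'steps': 4,
        'epochs': 8,
        'separate_memory_lr': True,
        'lr': 3e-4,
        'memory_lr': 1e-3,   # raised just for this stage
        'critic_lr': 1e-4,
        'weight_decay': 0.01,
        'critic_weight_decay': 0.01,
    }

A stage that sets all of these overrides to None falls through to the elif branch, which restores base_optim_config and rebuilds the optimizers once.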
@@ -745,6 +796,82 @@ class MRLTrainer:
 
         return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
+    def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
+        is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+        if is_staged_unfreeze:
+            update_epoch, fetch_epoch, joint_epoch, all_epoch = unfreeze_epoch
+
+            if isinstance(update_epoch, tuple):
+                switch_epoch, cross_att_lr = update_epoch
+                if epoch == switch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
+                    print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
+            elif epoch == update_epoch:
+                self.actor.freeze_components('update')
+                print("Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest of the model frozen")
+
+            if isinstance(fetch_epoch, tuple):
+                switch_epoch, mem_att_lr = fetch_epoch
+                if epoch == switch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
+                    print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
+            elif epoch == fetch_epoch:
+                self.actor.freeze_components('fetch')
+                print("Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest of the model frozen")
+
+            if isinstance(joint_epoch, tuple):
+                switch_epoch, model_lr = joint_epoch
+                if epoch == switch_epoch:
+                    self.actor.unfreeze_components()
+                    self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
+                    print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
+            elif epoch == joint_epoch:
+                self.actor.freeze_components('joint')
+                print("Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest of the model frozen")
+
+            if epoch == all_epoch:
+                self.actor.unfreeze_components()
+                self.optimizer = self._init_unfreeze_optimizer('all', 0.)
+                print("Switching to train 'all' strategy - unfreeze all components")
+        elif epoch == unfreeze_epoch:
+            self.actor.unfreeze_components()
+            print("Switching to train 'all' strategy - unfreeze all components")
+
+    def _init_unfreeze_optimizer(
+            self,
+            mode: Literal['update', 'fetch', 'joint', 'all'],
+            unfreeze_lr: float,
+    ) -> torch.optim.Optimizer:
+        memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
+        model_lr = self.optim_config['lr']
+
+        if mode == 'update':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'fetch':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'joint':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+        else:
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+
+        return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
+
+
     def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
         """Start Memory Reinforcement Learning Curriculum."""
 
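Tying the pieces together, a staged unfreeze_epoch value of type UnfreezeEpochsStrategy (epoch numbers and LRs here are illustrative):

    # (update_epoch, fetch_epoch, joint_epoch, all_epoch); each of the first three
    # may be an (epoch, lr) pair, which also rebuilds the optimizer via
    # _init_unfreeze_optimizer with a custom LR for the damped component.
    unfreeze_epoch = ((0, 1e-5), (2, 1e-5), (4, 1e-4), 6)
    # epoch 0: 'update' - memory attention at memory_lr, cross-attention damped to 1e-5
    # epoch 2: 'fetch'  - cross-attention at memory_lr, memory attention damped to 1e-5
    # epoch 4: 'joint'  - both memory paths at memory_lr, rest of the model at 1e-4
    # epoch 6: 'all'    - every component unfrozen

Plain integers work too: unfreeze_epoch=(0, 2, 4, 6) only toggles freeze_components without rebuilding optimizers, and a single int keeps the original "freeze everything until epoch N" behaviour.
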
@@ -770,7 +897,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components('both')
+            self.actor.freeze_components('joint')
+            if isinstance(unfreeze_epoch, tuple):
+                print("Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest of the model frozen")
+            else:
+                print("Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest of the model frozen")
 
         # 5. Setup train DataLoader
         if self.use_ddp:
@@ -810,21 +941,8 @@ class MRLTrainer:
         else:
             self.random_resets_ratio = 1.0
 
-        # 11. Unfreeze all components before selected epoch
-        is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
-        if is_staged_unfreeze:
-            update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
-            if epoch == update_epoch:
-                self.actor.freeze_components('update')
-            elif epoch == fetch_epoch:
-                self.actor.freeze_components('fetch')
-            elif epoch == both_epoch:
-                self.actor.freeze_components('both')
-            elif epoch == all_epoch:
-                self.actor.unfreeze_components()
-        else:
-            if epoch == unfreeze_epoch:
-                self.actor.unfreeze_components()
+        # 11. Apply the unfreeze strategy
+        self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
 
         # 12. Set epoch for distributed sampler
         if train_sampler is not None:
rxnn/transformers/layers.py CHANGED
@@ -64,6 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.norm2.parameters():
             param.requires_grad_(is_trainable)
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        memory_param_ids = {id(param) for param in self.memory_parameters()}
+        return [param for param in self.parameters() if id(param) not in memory_param_ids]
+
     def update_max_len(self, max_seq_len: int):
         if self.attention.rope is not None:
             self.attention.rope.update_max_len(max_seq_len)
rxnn/transformers/models.py CHANGED
@@ -39,6 +39,16 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable, with_norms)
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.memory_parameters()] if self.shared_layers else []
+        return own + shared
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.not_memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.not_memory_parameters()] if self.shared_layers else []
+        return own + shared
+
     def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
rxnn-0.2.27.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.25
+Version: 0.2.27
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.27.dist-info/RECORD CHANGED
@@ -6,17 +6,17 @@ rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
-rxnn/memory/norm.py,sha256=mu_6iZJe61ag627csfJN2JK6QmmzofjOEhxV4ZWblXs,6410
+rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
+rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=5fl1hESVj2Hakqz5to8ZJzw5Q4_RKZAUq2bn6nRiPV8,6045
-rxnn/training/mrl.py,sha256=14wx3pVha15B7eRWPRgoxRtV5dPtBI0yadIHOYZjX6k,43275
+rxnn/training/models.py,sha256=bY6yZoXYJEsrcymtb5Ep41vmFVHplCGWlrw1dI0oFRc,6807
+rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=UQZbrAg1UAttPASeqS7BP1a4JalktThmRMzX99Qghss,7618
+rxnn/transformers/layers.py,sha256=LXSY829fIHSCmFmClhQ6B7I5aKbiOqy9mZmwlJG_r7U,7961
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=_2qO1SASHtKvTW3dW-Dy9HEmAvoNVC1_addm2tM9Zbs,8325
+rxnn/transformers/models.py,sha256=QwVxYN9DrKllEpOiFoAx4CiThOWafeTa-OAY7L6gN0Y,8929
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.25.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.25.dist-info/METADATA,sha256=nuGFk4oqSMhn6vrw2KZs4RtY0_ZLowg29IlkNVHZ6Jo,25960
-rxnn-0.2.25.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.25.dist-info/RECORD,,
+rxnn-0.2.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.27.dist-info/METADATA,sha256=woZT3PVGgtEJP7DIAJv1-Mdfd4XvKoCRHANQgoTXoXk,25960
+rxnn-0.2.27.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.27.dist-info/RECORD,,