rxnn 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/rxt/models.py CHANGED
@@ -137,13 +137,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm
 
-    def freeze_without_memory(self):
+    def freeze_without_memory(self, unfreeze_norms: bool = True):
         for param in self.model.parameters():
             param.requires_grad_(False)
-        self.model.trainable_cross_attention_(True)
+        self.model.trainable_cross_attention_(True, with_norms=unfreeze_norms)
 
-    def freeze_memory(self):
-        self.model.trainable_cross_attention_(False)
+    def freeze_memory(self, with_norms: bool = True):
+        self.model.trainable_cross_attention_(False, with_norms=with_norms)
 
     def unfreeze_all(self):
         for param in self.model.parameters():
@@ -264,6 +264,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
 
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm
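Taken together, these changes let callers decide whether the memory cross-attention norm layers are frozen or unfrozen along with the cross-attention itself, and add whole-module freeze helpers to the memory attention component. A minimal usage sketch (the names `encoder` and `mem_att` are illustrative, standing for already-constructed RxTAlphaComponentBase and RxTAlphaMemoryAttention instances):

    # Freeze everything except memory cross-attention; with the new default,
    # the associated norm layers are unfrozen together with it.
    encoder.freeze_without_memory(unfreeze_norms=True)

    # Re-freeze the memory cross-attention path, norms included.
    encoder.freeze_memory(with_norms=True)

    # New in 0.2.24: whole-module helpers on the memory attention component.
    mem_att.freeze()    # sets requires_grad = False on every parameter
    mem_att.unfreeze()  # sets requires_grad = True on every parameter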
rxnn/training/models.py CHANGED
@@ -78,24 +78,30 @@ class MrlActorModel(nn.Module):
     def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
-            self.encoder.freeze_without_memory()
+            self.encoder.freeze_without_memory(unfreeze_norms=True)
             if stage == 'update':
-                self.encoder.freeze_memory()
+                self.encoder.freeze_memory(with_norms=True)
         else:
             for param in self.encoder.parameters():
                 param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
+            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
         if self.decoder.freeze_without_memory is not None:
-            self.decoder.freeze_without_memory()
+            self.decoder.freeze_without_memory(unfreeze_norms=True)
             if stage == 'update':
-                self.decoder.freeze_memory()
+                self.decoder.freeze_memory(with_norms=True)
         else:
             for param in self.decoder.parameters():
                 param.requires_grad = False
-            self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
+            self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
         # Unfreeze memory attention
-        for param in self.memory_attention.parameters():
-            param.requires_grad = True if stage != 'fetch' else False
+        if self.memory_attention.freeze is not None:
+            if stage == 'fetch':
+                self.memory_attention.freeze()
+            else:
+                self.memory_attention.unfreeze()
+        else:
+            for param in self.memory_attention.parameters():
+                param.requires_grad = True if stage != 'fetch' else False
 
     def unfreeze_components(self):
         """Unfreeze all components after initial training."""
@@ -109,8 +115,11 @@ class MrlActorModel(nn.Module):
         else:
             for param in self.decoder.parameters():
                 param.requires_grad = True
-        for param in self.memory_attention.parameters():
-            param.requires_grad = True
+        if self.memory_attention.unfreeze is not None:
+            self.memory_attention.unfreeze()
+        else:
+            for param in self.memory_attention.parameters():
+                param.requires_grad = True
 
     def reset_memory(self):
         self.memory_attention.reset_memory()
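freeze_components() now prefers the new freeze()/unfreeze() helpers on the memory attention component, keeping the per-parameter loop only as a fallback, and always toggles cross-attention together with its norms. A rough summary of what each stage leaves trainable (assuming an already-constructed MrlActorModel named `actor`; the name is illustrative):

    actor.freeze_components('update')  # only memory attention trainable
    actor.freeze_components('fetch')   # only memory cross-attention (+ norms) trainable
    actor.freeze_components('both')    # memory attention and cross-attention trainable
    actor.unfreeze_components()        # everything trainable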
rxnn/training/mrl.py CHANGED
@@ -37,7 +37,7 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
+    unfreeze_epoch: Optional[Union[int, tuple[int, int, int, int]]]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
@@ -770,11 +770,7 @@ class MRLTrainer:
 
        # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
        if unfreeze_epoch != 0:
-           is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
-           if is_staged_unfreeze:
-               self.actor.freeze_components('update')
-           else:
-               self.actor.freeze_components()
+           self.actor.freeze_components('both')
 
        # 5. Setup train DataLoader
        if self.use_ddp:
@@ -817,8 +813,10 @@
            # 11. Unfreeze all components before selected epoch
            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
            if is_staged_unfreeze:
-               fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
-               if epoch == fetch_epoch:
+               update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
+               if epoch == update_epoch:
+                   self.actor.freeze_components('update')
+               elif epoch == fetch_epoch:
                    self.actor.freeze_components('fetch')
                elif epoch == both_epoch:
                    self.actor.freeze_components('both')
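With the new leading 'update' stage, a staged unfreeze_epoch tuple now carries four epoch indices, unpacked as (update_epoch, fetch_epoch, both_epoch, all_epoch); the plain-int form is still accepted per the retained Union type. A hypothetical curriculum fragment (epoch values are illustrative; other CurriculumConfig fields are omitted for brevity):

    stage_config = {
        # switch to 'update' at epoch 2, 'fetch' at 4, 'both' at 6, unfreeze all at 8
        'unfreeze_epoch': (2, 4, 6, 8),
    }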
rxnn/transformers/layers.py CHANGED
@@ -57,9 +57,12 @@ class ReactiveTransformerLayer(nn.Module):
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
 
-    def trainable_cross_attention_(self, is_trainable: bool):
+    def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
         for param in self.memory_cross_attention.parameters():
             param.requires_grad_(is_trainable)
+        if with_norms:
+            for param in self.norm2.parameters():
+                param.requires_grad_(is_trainable)
 
     def update_max_len(self, max_seq_len: int):
         if self.attention.rope is not None:
rxnn/transformers/models.py CHANGED
@@ -33,11 +33,11 @@ class ReactiveTransformerBase(nn.Module):
         self.num_shared_layers = len(shared_layers) if shared_layers else 0
         self.num_own_layers = len(own_layers) if own_layers else 0
 
-    def trainable_cross_attention_(self, is_trainable: bool):
+    def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
         for i in range(self.num_shared_layers):
-            self.shared_layers[i].trainable_cross_attention_(is_trainable)
+            self.shared_layers[i].trainable_cross_attention_(is_trainable, with_norms)
         for i in range(self.num_own_layers):
-            self.layers[i].trainable_cross_attention_(is_trainable)
+            self.layers[i].trainable_cross_attention_(is_trainable, with_norms)
 
     def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
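The with_norms flag is threaded from ReactiveTransformerBase down to every shared and own layer, so a single call flips the layer-local norm2 parameters along with the cross-attention weights. A quick, hypothetical sanity check of the effect (`model` stands for any ReactiveTransformerBase instance):

    def count_trainable(module) -> int:
        # Number of parameter elements that will receive gradients.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    # Freeze everything, then unfreeze cross-attention without norms...
    for p in model.parameters():
        p.requires_grad_(False)
    model.trainable_cross_attention_(True, with_norms=False)
    attn_only = count_trainable(model)

    # ...and again with norms: the difference is the norm2 parameter count.
    for p in model.parameters():
        p.requires_grad_(False)
    model.trainable_cross_attention_(True, with_norms=True)
    attn_plus_norms = count_trainable(model)
    assert attn_plus_norms >= attn_only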
rxnn-0.2.23.dist-info/METADATA → rxnn-0.2.24.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.23
+Version: 0.2.24
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.23.dist-info/RECORD → rxnn-0.2.24.dist-info/RECORD RENAMED
@@ -9,14 +9,14 @@ rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
+rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=wf98gYKKm9-ZY3zwdX9NIeJ-lvh7Ro1SoAijmQxYM28,5599
-rxnn/training/mrl.py,sha256=7flR_ZvgSIWvY6JjbRLYJwZnD6cO55N1-zwJIf5VHQ8,43334
+rxnn/training/models.py,sha256=5fl1hESVj2Hakqz5to8ZJzw5Q4_RKZAUq2bn6nRiPV8,6045
+rxnn/training/mrl.py,sha256=14wx3pVha15B7eRWPRgoxRtV5dPtBI0yadIHOYZjX6k,43275
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=MbOIX4PurbTbYxcXSavyFsNpTHCm26K_Ssk_VUCzKIE,7469
+rxnn/transformers/layers.py,sha256=UQZbrAg1UAttPASeqS7BP1a4JalktThmRMzX99Qghss,7618
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
+rxnn/transformers/models.py,sha256=_2qO1SASHtKvTW3dW-Dy9HEmAvoNVC1_addm2tM9Zbs,8325
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.23.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.23.dist-info/METADATA,sha256=CqyER_fR5UiM00I8t2x41ORbwzagxq8pv-Eduy2RAng,25960
-rxnn-0.2.23.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.23.dist-info/RECORD,,
+rxnn-0.2.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.24.dist-info/METADATA,sha256=PrVfcCd8NBFtFnD8lAJqU7UW3lLEc-Tr7MQhK6obvuo,25960
+rxnn-0.2.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.24.dist-info/RECORD,,