rxnn 0.2.42__tar.gz → 0.2.44__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.42 → rxnn-0.2.44}/PKG-INFO +1 -1
  2. {rxnn-0.2.42 → rxnn-0.2.44}/pyproject.toml +1 -1
  3. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/stm.py +3 -0
  4. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/rxt/models.py +3 -0
  5. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/models.py +6 -0
  6. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/mrl.py +8 -6
  7. {rxnn-0.2.42 → rxnn-0.2.44}/LICENSE +0 -0
  8. {rxnn-0.2.42 → rxnn-0.2.44}/README.md +0 -0
  9. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.42
+Version: 0.2.44
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.42"
+version = "0.2.44"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
     def reset(self, init_type: str = None):
         self.memory = self._init_tensor(init_type).to(self.memory.device)
 
+    def clone_detach_reset(self):
+        self.memory = self.memory.detach().clone()
+
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
         delattr(self, 'memory')
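
The new clone_detach_reset() overwrites the stored memory with detach().clone() of itself: the values are unchanged, but the tensor no longer carries autograd history from earlier updates. A minimal illustration of that effect on standalone tensors (not the actual ShortTermMemory buffer):

    import torch

    memory = torch.randn(4, 8, requires_grad=True) * 2.0  # an update that records grad history
    assert memory.grad_fn is not None

    memory = memory.detach().clone()   # what clone_detach_reset() stores back
    assert memory.grad_fn is None      # the autograd graph is dropped
    assert not memory.requires_grad    # values are kept, history is not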
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def reset_memory(self, init_type: str = None):
         self.model.stm.reset(init_type)
 
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def clone_reset_memory(self):
+        self.memory_attention.clone_reset_memory()
+
     def memory_parameters(self) -> list[nn.Parameter]:
         return list(set(
             self.encoder.memory_parameters() +
@@ -168,6 +171,9 @@ class MrlActorModel(nn.Module):
             self.decoder.not_memory_parameters()
         ))
 
+    def embedding_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.model.embedding.parameters()
+
     def unique_parameters(self, with_embedding: bool = True):
         if with_embedding:
             return list(set(
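
embedding_parameters() returns the iterator from self.encoder.model.embedding.parameters(), so the trainer can address the shared embedding without reaching into the encoder's internals; note that the Iterator[nn.Parameter] annotation assumes Iterator is imported (e.g. from typing) in this module. A brief usage sketch with a hypothetical actor instance:

    # inspect the embedding parameters through the new accessor
    embedding_params = list(actor.embedding_parameters())
    num_weights = sum(p.numel() for p in embedding_params)
    print(f"embedding tensors: {len(embedding_params)}, total weights: {num_weights}")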
@@ -225,7 +225,7 @@ class MRLTrainer:
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.not_memory_parameters(), 'lr': lr},
                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
@@ -233,7 +233,7 @@ class MRLTrainer:
             )
         else:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
             ],
                 weight_decay=weight_decay,
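
Both optimizer branches now build their AdamW parameter groups from MrlActorModel helpers instead of the old self.actor.encoder.embedding.parameters() path. A condensed, hypothetical helper mirroring the two branches shown above (names like embedding_lr, memory_lr and weight_decay are taken from the diff; the function itself is not part of the package):

    import torch

    def build_actor_optimizer(actor, lr, embedding_lr, weight_decay, memory_lr=None):
        # separate learning rate for memory-related parameters when memory_lr is given
        if memory_lr is not None:
            groups = [
                {'params': actor.embedding_parameters(), 'lr': embedding_lr},
                {'params': actor.not_memory_parameters(), 'lr': lr},
                {'params': actor.memory_parameters(), 'lr': memory_lr},
            ]
        else:
            groups = [
                {'params': actor.embedding_parameters(), 'lr': embedding_lr},
                {'params': actor.unique_parameters(with_embedding=False), 'lr': lr},
            ]
        return torch.optim.AdamW(groups, weight_decay=weight_decay)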
@@ -646,6 +646,8 @@ class MRLTrainer:
                 step_critic_values = episode_critic_values[step_idx]
                 step_advantages = episode_advantages[step_idx]
 
+                self.actor.clone_reset_memory()
+
                 # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
                 if self.memory_aware_critic:
                     self.encode_and_update_stm(query, answer)
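
Calling self.actor.clone_reset_memory() at the start of each step's update keeps the STM content produced during the rollout but detaches it, so the current step's backward pass does not propagate through earlier steps' memory writes. A simplified sketch of where the call sits (the loop structure and every name except clone_reset_memory are illustrative, not the trainer's actual code):

    for step_idx in range(num_steps):
        actor.clone_reset_memory()        # keep STM values, drop the old autograd graph

        logits = actor(step_inputs[step_idx])              # this step reads/updates the detached STM
        loss = policy_loss(logits, step_advantages[step_idx])

        optimizer.zero_grad()
        loss.backward()                   # gradients stop at the detached memory tensor
        optimizer.step()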
@@ -929,7 +931,7 @@ class MRLTrainer:
 
         if mode == 'update':
             params = [
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -938,7 +940,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -947,7 +949,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -956,7 +958,7 @@ class MRLTrainer:
             ]
         else:
             params = [
-                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
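
Across the 'update', 'fetch', 'joint', and default branches only the learning-rate assignment differs; the parameter groups themselves are identical and all now go through embedding_parameters(). A compact restatement of the mapping for the groups visible in these hunks (a hypothetical constant written only to summarize the diff, not part of the package):

    # lr used per parameter group in each mode, as read from the hunks above
    MODE_LRS = {
        'update':  {'embedding': 'embedding_lr', 'encoder_not_memory': 'model_lr',
                    'encoder_memory': 'memory_lr', 'memory_attention': 'memory_lr'},
        'fetch':   {'embedding': 'unfreeze_lr', 'encoder_not_memory': 'unfreeze_lr',
                    'encoder_memory': 'unfreeze_lr', 'memory_attention': 'unfreeze_lr'},
        'joint':   {'embedding': 'unfreeze_lr', 'encoder_not_memory': 'unfreeze_lr',
                    'encoder_memory': 'memory_lr', 'memory_attention': 'memory_lr'},
        'default': {'embedding': 'embedding_lr', 'encoder_not_memory': 'model_lr',
                    'encoder_memory': 'memory_lr', 'memory_attention': 'memory_lr'},
    }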