rxnn 0.2.43__tar.gz → 0.2.44__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.43 → rxnn-0.2.44}/PKG-INFO +1 -1
  2. {rxnn-0.2.43 → rxnn-0.2.44}/pyproject.toml +1 -1
  3. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/memory/stm.py +3 -0
  4. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/rxt/models.py +3 -0
  5. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/models.py +3 -0
  6. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/mrl.py +2 -0
  7. {rxnn-0.2.43 → rxnn-0.2.44}/LICENSE +0 -0
  8. {rxnn-0.2.43 → rxnn-0.2.44}/README.md +0 -0
  9. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.43 → rxnn-0.2.44}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.43
3
+ Version: 0.2.44
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.2.43"
7
+ version = "0.2.44"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
62
62
  def reset(self, init_type: str = None):
63
63
  self.memory = self._init_tensor(init_type).to(self.memory.device)
64
64
 
65
+ def clone_detach_reset(self):
66
+ self.memory = self.memory.detach().clone()
67
+
65
68
  def resize(self, new_stm_size: int, init_type: str = None):
66
69
  self.stm_size = new_stm_size
67
70
  delattr(self, 'memory')
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
301
301
  def reset_memory(self, init_type: str = None):
302
302
  self.model.stm.reset(init_type)
303
303
 
304
+ def clone_reset_memory(self):
305
+ self.model.stm.clone_detach_reset()
306
+
304
307
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
305
308
  return self.model(x, attention_mask=attention_mask)
306
309
 
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
146
146
  def reset_memory(self):
147
147
  self.memory_attention.reset_memory()
148
148
 
149
+ def clone_reset_memory(self):
150
+ self.memory_attention.clone_reset_memory()
151
+
149
152
  def memory_parameters(self) -> list[nn.Parameter]:
150
153
  return list(set(
151
154
  self.encoder.memory_parameters() +
@@ -646,6 +646,8 @@ class MRLTrainer:
646
646
  step_critic_values = episode_critic_values[step_idx]
647
647
  step_advantages = episode_advantages[step_idx]
648
648
 
649
+ self.actor.clone_reset_memory()
650
+
649
651
  # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
650
652
  if self.memory_aware_critic:
651
653
  self.encode_and_update_stm(query, answer)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes