rxnn 0.2.12__tar.gz → 0.2.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {rxnn-0.2.12 → rxnn-0.2.13}/PKG-INFO +1 -1
  2. {rxnn-0.2.12 → rxnn-0.2.13}/pyproject.toml +1 -1
  3. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/memory/stm.py +3 -3
  4. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/mrl.py +2 -3
  5. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/rl.py +2 -0
  6. {rxnn-0.2.12 → rxnn-0.2.13}/LICENSE +0 -0
  7. {rxnn-0.2.12 → rxnn-0.2.13}/README.md +0 -0
  8. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/rxt/__init__.py +0 -0
  18. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/rxt/models.py +0 -0
  19. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/scheduler.py +0 -0
  27. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/tokenizer.py +0 -0
  28. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/utils.py +0 -0
  29. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/__init__.py +0 -0
  30. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/attention.py +0 -0
  31. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/ff.py +0 -0
  32. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/layers.py +0 -0
  33. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/mask.py +0 -0
  34. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/models.py +0 -0
  35. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/moe.py +0 -0
  36. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/positional.py +0 -0
  37. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/transformers/sampler.py +0 -0
  38. {rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/utils.py +0 -0
{rxnn-0.2.12 → rxnn-0.2.13}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.12
+Version: 0.2.13
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.12 → rxnn-0.2.13}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.12"
+version = "0.2.13"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/memory/stm.py
@@ -39,6 +39,7 @@ class ShortTermMemory(nn.Module):
         return self.memory[layer]
 
     def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory = self.memory.clone()
         self.memory[layer] = new_stm
 
     def update_all(self, new_stm: torch.Tensor):
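
The update_layer change clones the memory buffer before the in-place index assignment. Below is a minimal sketch of that clone-then-assign pattern; the tensor shapes are assumed for illustration and are not the library's real layout. The point is that any autograd graph still holding the old buffer is left untouched, while gradients still flow into the newly written layer state.

```python
import torch

# Minimal sketch of the clone-then-assign pattern from update_layer().
# Shapes are hypothetical; they are not taken from the library.
memory = torch.zeros(4, 2, 8, 16)             # (layers, batch, stm_size, dim) - assumed
new_layer_state = torch.randn(2, 8, 16, requires_grad=True)

memory = memory.clone()                       # fresh tensor: any graph that saved the
                                              # old buffer is left untouched
memory[1] = new_layer_state                   # the in-place write goes into the clone only
loss = memory.sum()
loss.backward()
print(new_layer_state.grad.shape)             # torch.Size([2, 8, 16]) - gradients still flow
```
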
@@ -60,7 +61,7 @@ class ShortTermMemory(nn.Module):
         self.register_buffer('memory', trained_stm)
 
     def reset(self, init_type: str = None):
-        self.memory.copy_(self._init_tensor(init_type))
+        self.memory = self._init_tensor(init_type).to(self.memory.device)
 
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
@@ -85,8 +86,7 @@ class ShortTermMemory(nn.Module):
         if use_mean_from_batch:
             batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
             delattr(self, 'memory')
-            self.register_buffer('memory', self._init_tensor())
-            self.memory.copy_(batch_mean)
+            self.register_buffer('memory', batch_mean)
         else:
             delattr(self, 'memory')
             self.register_buffer('memory', self._init_tensor())
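
The reset and batch-mean paths now replace the buffer instead of copying into it in place. A quick sanity check follows (an illustrative toy module, not the library's ShortTermMemory) that reassigning a name previously registered with register_buffer() keeps it tracked as a buffer, which is what these out-of-place updates rely on. Note that after mean(dim=(1, 2, 3), keepdim=True) the registered tensor presumably shrinks to a broadcastable (num_layers, 1, 1, 1) shape; that layout is an assumption read off the keepdim call, not something this diff states.

```python
import torch
import torch.nn as nn

# Toy module showing that plain reassignment of a registered buffer name
# keeps it in named_buffers() and in the state_dict.
class TinySTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('memory', torch.zeros(2, 1, 4, 8))

    def reset(self):
        # out-of-place reset, as in 0.2.13, instead of self.memory.copy_(...)
        self.memory = torch.randn_like(self.memory)

stm = TinySTM()
stm.reset()
print('memory' in dict(stm.named_buffers()))   # True - still tracked as a buffer
print('memory' in stm.state_dict())            # True - still serialized
```
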
{rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/mrl.py
@@ -461,7 +461,6 @@ class MRLTrainer:
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
-        trajectories_len = len(trajectories)
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
@@ -514,14 +513,14 @@ class MRLTrainer:
 
                 # 9. Update the model in AMP or regular mode
                 if self.use_amp:
-                    self.scaler.scale(policy_loss).backward()
+                    self.scaler.scale(policy_loss).backward(retain_graph=True)
                     self.scaler.unscale_(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                    error_if_nonfinite=False)
                     self.scaler.step(self.optimizer)
                     self.scaler.update()
                 else:
-                    policy_loss.backward()
+                    policy_loss.backward(retain_graph=True)
                     torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                    error_if_nonfinite=False)
                     self.optimizer.step()
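
retain_graph=True lets backward() run more than once over graph parts that are shared between losses, which is the general PyTorch behavior this change leans on when consecutive MRL updates reuse state from the same computation graph. A toy illustration (not the trainer itself) is sketched below; keeping the graph alive does retain more memory, so it is usually paired with a point where the shared state is detached or rebuilt.

```python
import torch

# Two losses share the intermediate tensor `shared`. The first backward()
# would normally free the shared graph; retain_graph=True keeps it alive
# so the second backward() does not raise
# "Trying to backward through the graph a second time".
x = torch.randn(3, requires_grad=True)
shared = x * 2                      # stands in for activations reused across steps
loss_a = shared.sum()
loss_b = (shared ** 2).sum()

loss_a.backward(retain_graph=True)  # keep the graph for the next backward pass
loss_b.backward()                   # works because the shared graph was retained
print(x.grad)                       # gradients from both losses accumulate here
```
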
{rxnn-0.2.12 → rxnn-0.2.13}/src/rxnn/training/rl.py
@@ -43,6 +43,8 @@ class PPOAlgorithm(RlAlgorithm):
         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
 
+        advantages = advantages.unsqueeze(-1)
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
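
The new unsqueeze gives the advantages a trailing dimension so they broadcast against per-token ratios instead of mis-aligning. Below is a shape sketch with assumed dimensions (the actual tensor shapes in rl.py are not shown in this diff): advantages are taken as one value per sequence, while the log-prob ratios are per token.

```python
import torch

# Shape sketch of the PPO clipped-surrogate fix; shapes are assumptions.
batch, seq_len = 4, 10
new_log_probs = torch.randn(batch, seq_len)
old_log_probs = torch.randn(batch, seq_len)
advantages = torch.randn(batch)                 # one advantage per sequence (assumed)

ratio = (new_log_probs - old_log_probs).exp()   # (batch, seq_len)
advantages = advantages.unsqueeze(-1)           # (batch, 1) -> broadcasts over tokens
clip_eps = 0.2
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss.shape)                        # torch.Size([]) - scalar loss
```
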