rxnn 0.2.49__tar.gz → 0.2.50__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.49 → rxnn-0.2.50}/PKG-INFO +1 -1
  2. {rxnn-0.2.49 → rxnn-0.2.50}/pyproject.toml +1 -1
  3. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/reward.py +2 -2
  4. {rxnn-0.2.49 → rxnn-0.2.50}/LICENSE +0 -0
  5. {rxnn-0.2.49 → rxnn-0.2.50}/README.md +0 -0
  6. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/experimental/attention.py +0 -0
  10. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/experimental/models.py +0 -0
  11. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/memory/attention.py +0 -0
  14. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.49 → rxnn-0.2.50}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.49
3
+ Version: 0.2.50
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.2.49"
7
+ version = "0.2.50"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -242,8 +242,8 @@ class MrlRewardModel:
242
242
  return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
243
243
 
244
244
  def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
245
- target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
246
- lens = generated['attention_mask'].sum(dim=1)
245
+ target_lens = reference['attention_mask'].to(self.device).sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
246
+ lens = generated['attention_mask'].to(self.device).sum(dim=1)
247
247
  neg_lens = target_lens / lens if self.neg_reward_len else 1.0
248
248
  len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
249
249
  return len_reward
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes