rxnn 0.2.62__tar.gz → 0.2.63__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.62 → rxnn-0.2.63}/PKG-INFO +1 -1
  2. {rxnn-0.2.62 → rxnn-0.2.63}/pyproject.toml +1 -1
  3. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/mrl.py +14 -3
  4. {rxnn-0.2.62 → rxnn-0.2.63}/LICENSE +0 -0
  5. {rxnn-0.2.62 → rxnn-0.2.63}/README.md +0 -0
  6. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/experimental/attention.py +0 -0
  10. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/experimental/models.py +0 -0
  11. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/memory/attention.py +0 -0
  14. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/utils.py +0 -0
{rxnn-0.2.62 → rxnn-0.2.63}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.62
+Version: 0.2.63
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.62 → rxnn-0.2.63}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.62"
+version = "0.2.63"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.62 → rxnn-0.2.63}/src/rxnn/training/mrl.py
@@ -243,20 +243,31 @@ class MRLTrainer:
             critic_weight_decay: float,
             critic_encoder_lr: float,
             embedding_lr: float,
+            encoder_lr: float,
             memory_lr: Optional[float] = None,
+            encoder_memory_lr: Optional[float] = None,
+            memory_attn_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.not_memory_parameters(), 'lr': lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
             )
         else:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
             )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes