rxnn 0.2.41__py3-none-any.whl → 0.2.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py CHANGED
@@ -1022,7 +1022,7 @@ class MRLTrainer:
1022
1022
 
1023
1023
  return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
1024
1024
 
1025
- def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
1025
+ def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int, ddp_find_unused_parameters: bool = False):
1026
1026
  """Start Memory Reinforcement Learning Curriculum."""
1027
1027
 
1028
1028
  # 0. Set global epoch count for all stages
@@ -1033,7 +1033,7 @@ class MRLTrainer:
1033
1033
  if self.use_ddp:
1034
1034
  rank, world_size = get_os_ddp_config()
1035
1035
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
1036
- self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
1036
+ self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
1037
1037
  self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
1038
1038
 
1039
1039
  # 2. Init BatchSampler with actor model (we have to run it after DDP init)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.41
3
+ Version: 0.2.42
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35
17
17
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
18
18
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
19
  rxnn/training/models.py,sha256=4hDH-R9l1lNvBMW_CGG_QgmCVrkyG7Lyo40PPzvkovQ,8876
20
- rxnn/training/mrl.py,sha256=tv7LjW1HBXF9H7rrITQD4EmN1-qgJT44UblREzsjeew,59378
20
+ rxnn/training/mrl.py,sha256=xHH-tcmvwmwV5wwiAa3DaXLuF5OipmVDDYxLL5wOYVM,59471
21
21
  rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
22
22
  rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.41.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.41.dist-info/METADATA,sha256=5oKrThfhnOQK8KjDYJfcP-LTb03hNyUrSTjbOSpUUdg,25960
38
- rxnn-0.2.41.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.41.dist-info/RECORD,,
36
+ rxnn-0.2.42.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.42.dist-info/METADATA,sha256=5ZND9je7xzC5qCXQmyFB0XKedtqe5gicSqZnRui1K0Q,25960
38
+ rxnn-0.2.42.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.42.dist-info/RECORD,,
File without changes
File without changes