rxnn 0.2.37__py3-none-any.whl → 0.2.38__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -82,6 +82,7 @@ class BaseTrainer(ABC):
82
82
  dataset: torch.utils.data.Dataset = None,
83
83
  optimizer: torch.optim.Optimizer = None,
84
84
  scheduler: torch.optim.lr_scheduler.LRScheduler = None,
85
+ ddp_find_unused_parameters: bool = False,
85
86
  ) -> None:
86
87
  self.is_running = True
87
88
  if dataset is None:
@@ -94,7 +95,7 @@ class BaseTrainer(ABC):
94
95
  if self.use_ddp:
95
96
  rank, world_size = get_os_ddp_config()
96
97
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
97
- self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
98
+ self.model = DistributedDataParallel(self.model, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
98
99
  train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
99
100
  dataloader = torch.utils.data.DataLoader(
100
101
  dataset,
rxnn/training/bml.py CHANGED
@@ -51,6 +51,10 @@ class MLMTrainer(BaseTrainer):
51
51
  model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
52
52
 
53
53
  router_loss = model.encoder.model.moe_router_loss()
54
+
55
+ if self.use_ddp:
56
+ router_loss = distributed_mean(router_loss)
57
+
54
58
  loss = main_loss + self.moe_aux_loss_scale * router_loss
55
59
 
56
60
  if self.writer is not None:
@@ -152,6 +156,10 @@ class AutoregressiveTrainer(BaseTrainer):
152
156
  model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
153
157
 
154
158
  router_loss = model.model.moe_router_loss()
159
+
160
+ if self.use_ddp:
161
+ router_loss = distributed_mean(router_loss)
162
+
155
163
  loss = main_loss + self.moe_aux_loss_scale * router_loss
156
164
 
157
165
  if self.writer is not None:
rxnn/training/rl.py CHANGED
@@ -116,17 +116,16 @@ class PPOAlgorithm(RlAlgorithm):
116
116
  dones = dones.float()
117
117
 
118
118
  for t in reversed(range(trajectory_len)):
119
- if t == trajectory_len - 1:
120
- # For the last step, use the provided last_value
121
- delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
122
- else:
123
- # For other steps, use the next value in the trajectory
124
- delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
125
-
119
+ # Calculate delta from rewards, stored next_value, masked by stored next_done, and values
120
+ delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
121
+ # Calculate advantages based on delta, gamma/lambda factors and last advantage, masked by current done flags
126
122
  advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
123
+ # Store current step data as last_advantage, next_done and next_value, for the next iteration step
127
124
  last_advantage = advantages[t]
128
125
  next_done = dones[t]
126
+ next_value = values[t]
129
127
 
128
+ # Calculate reference returns, based on advantages and values, and return them with advantages for critic update
130
129
  returns = advantages + values
131
130
  return advantages, returns
132
131
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.37
3
+ Version: 0.2.38
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -11,15 +11,15 @@ rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  rxnn/rxt/models.py,sha256=CzFELVv5-ybAwl1s1ptpmwM7wdJ07M4jaT1-I8PYrR0,13999
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- rxnn/training/base.py,sha256=TGz_37RfI1qLI31GNRV5rLowW1kAHnJwqPm7DNfLfe4,11730
15
- rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
14
+ rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
15
+ rxnn/training/bml.py,sha256=-Al_qHvdjX4MR3YO-905dJHus405A2Dg-9uWmO561KU,17080
16
16
  rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
17
17
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
18
18
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
19
  rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
20
20
  rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
21
21
  rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
22
- rxnn/training/rl.py,sha256=47wxFeUSHSqc1dKEEy8skTcNHDqNuthsYTGA-HeUbhg,5982
22
+ rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
24
24
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
25
25
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.37.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.37.dist-info/METADATA,sha256=GXBCyK-3ALJw6TpVk7rJ7Z2uyFSq8u8N-TYpmQaeUE8,25960
38
- rxnn-0.2.37.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.37.dist-info/RECORD,,
36
+ rxnn-0.2.38.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.38.dist-info/METADATA,sha256=qS0cXhbW6h6lbqGW_OZvO3x5WERF26OeafDaV7QI8dM,25960
38
+ rxnn-0.2.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.38.dist-info/RECORD,,
File without changes
File without changes