rxnn 0.2.36__py3-none-any.whl → 0.2.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +2 -1
- rxnn/training/bml.py +8 -0
- rxnn/training/rl.py +7 -8
- {rxnn-0.2.36.dist-info → rxnn-0.2.38.dist-info}/METADATA +1 -1
- {rxnn-0.2.36.dist-info → rxnn-0.2.38.dist-info}/RECORD +7 -7
- {rxnn-0.2.36.dist-info → rxnn-0.2.38.dist-info}/LICENSE +0 -0
- {rxnn-0.2.36.dist-info → rxnn-0.2.38.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -82,6 +82,7 @@ class BaseTrainer(ABC):
|
|
82
82
|
dataset: torch.utils.data.Dataset = None,
|
83
83
|
optimizer: torch.optim.Optimizer = None,
|
84
84
|
scheduler: torch.optim.lr_scheduler.LRScheduler = None,
|
85
|
+
ddp_find_unused_parameters: bool = False,
|
85
86
|
) -> None:
|
86
87
|
self.is_running = True
|
87
88
|
if dataset is None:
|
@@ -94,7 +95,7 @@ class BaseTrainer(ABC):
|
|
94
95
|
if self.use_ddp:
|
95
96
|
rank, world_size = get_os_ddp_config()
|
96
97
|
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
|
97
|
-
self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
|
98
|
+
self.model = DistributedDataParallel(self.model, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
|
98
99
|
train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
99
100
|
dataloader = torch.utils.data.DataLoader(
|
100
101
|
dataset,
|
rxnn/training/bml.py
CHANGED
@@ -51,6 +51,10 @@ class MLMTrainer(BaseTrainer):
|
|
51
51
|
model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
|
52
52
|
|
53
53
|
router_loss = model.encoder.model.moe_router_loss()
|
54
|
+
|
55
|
+
if self.use_ddp:
|
56
|
+
router_loss = distributed_mean(router_loss)
|
57
|
+
|
54
58
|
loss = main_loss + self.moe_aux_loss_scale * router_loss
|
55
59
|
|
56
60
|
if self.writer is not None:
|
@@ -152,6 +156,10 @@ class AutoregressiveTrainer(BaseTrainer):
|
|
152
156
|
model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
|
153
157
|
|
154
158
|
router_loss = model.model.moe_router_loss()
|
159
|
+
|
160
|
+
if self.use_ddp:
|
161
|
+
router_loss = distributed_mean(router_loss)
|
162
|
+
|
155
163
|
loss = main_loss + self.moe_aux_loss_scale * router_loss
|
156
164
|
|
157
165
|
if self.writer is not None:
|
rxnn/training/rl.py
CHANGED
@@ -112,21 +112,20 @@ class PPOAlgorithm(RlAlgorithm):
|
|
112
112
|
advantages = torch.zeros_like(rewards, device=rewards.device)
|
113
113
|
last_advantage = 0
|
114
114
|
next_value = last_value
|
115
|
-
next_done = last_done
|
115
|
+
next_done = last_done.float()
|
116
116
|
dones = dones.float()
|
117
117
|
|
118
118
|
for t in reversed(range(trajectory_len)):
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
else:
|
123
|
-
# For other steps, use the next value in the trajectory
|
124
|
-
delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
|
125
|
-
|
119
|
+
# Calculate delta from rewards, stored next_value, masked by stored next_done, and values
|
120
|
+
delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
|
121
|
+
# Calculate advantages based on delta, gamma/lambda factors and last advantage, masked by current done flags
|
126
122
|
advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
|
123
|
+
# Store current step data as last_advantage, next_done and next_value, for the next iteration step
|
127
124
|
last_advantage = advantages[t]
|
128
125
|
next_done = dones[t]
|
126
|
+
next_value = values[t]
|
129
127
|
|
128
|
+
# Calculate reference returns, based on advantages and values, and return them with advantages for critic update
|
130
129
|
returns = advantages + values
|
131
130
|
return advantages, returns
|
132
131
|
|
@@ -11,15 +11,15 @@ rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
|
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
rxnn/rxt/models.py,sha256=CzFELVv5-ybAwl1s1ptpmwM7wdJ07M4jaT1-I8PYrR0,13999
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
rxnn/training/base.py,sha256=
|
15
|
-
rxnn/training/bml.py,sha256
|
14
|
+
rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
|
15
|
+
rxnn/training/bml.py,sha256=-Al_qHvdjX4MR3YO-905dJHus405A2Dg-9uWmO561KU,17080
|
16
16
|
rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
18
18
|
rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
19
19
|
rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
|
20
20
|
rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
|
-
rxnn/training/rl.py,sha256=
|
22
|
+
rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
24
24
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
25
25
|
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.38.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.38.dist-info/METADATA,sha256=qS0cXhbW6h6lbqGW_OZvO3x5WERF26OeafDaV7QI8dM,25960
|
38
|
+
rxnn-0.2.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.38.dist-info/RECORD,,
|
File without changes
|
File without changes
|