rxnn 0.2.37__py3-none-any.whl → 0.2.39__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
rxnn/training/base.py
CHANGED
@@ -82,6 +82,7 @@ class BaseTrainer(ABC):
|
|
82
82
|
dataset: torch.utils.data.Dataset = None,
|
83
83
|
optimizer: torch.optim.Optimizer = None,
|
84
84
|
scheduler: torch.optim.lr_scheduler.LRScheduler = None,
|
85
|
+
ddp_find_unused_parameters: bool = False,
|
85
86
|
) -> None:
|
86
87
|
self.is_running = True
|
87
88
|
if dataset is None:
|
@@ -94,7 +95,7 @@ class BaseTrainer(ABC):
|
|
94
95
|
if self.use_ddp:
|
95
96
|
rank, world_size = get_os_ddp_config()
|
96
97
|
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
|
97
|
-
self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
|
98
|
+
self.model = DistributedDataParallel(self.model, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
|
98
99
|
train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
99
100
|
dataloader = torch.utils.data.DataLoader(
|
100
101
|
dataset,
|
rxnn/training/rl.py
CHANGED
@@ -116,17 +116,16 @@ class PPOAlgorithm(RlAlgorithm):
|
|
116
116
|
dones = dones.float()
|
117
117
|
|
118
118
|
for t in reversed(range(trajectory_len)):
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
else:
|
123
|
-
# For other steps, use the next value in the trajectory
|
124
|
-
delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
|
125
|
-
|
119
|
+
# Calculate delta from rewards, stored next_value, masked by stored next_done, and values
|
120
|
+
delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
|
121
|
+
# Calculate advantages based on delta, gamma/lambda factors and last advantage, masked by current done flags
|
126
122
|
advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
|
123
|
+
# Store current step data as last_advantage, next_done and next_value, for the next iteration step
|
127
124
|
last_advantage = advantages[t]
|
128
125
|
next_done = dones[t]
|
126
|
+
next_value = values[t]
|
129
127
|
|
128
|
+
# Calculate reference returns, based on advantages and values, and return them with advantages for critic update
|
130
129
|
returns = advantages + values
|
131
130
|
return advantages, returns
|
132
131
|
|
@@ -11,7 +11,7 @@ rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
|
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
rxnn/rxt/models.py,sha256=CzFELVv5-ybAwl1s1ptpmwM7wdJ07M4jaT1-I8PYrR0,13999
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
rxnn/training/base.py,sha256=
|
14
|
+
rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
|
15
15
|
rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
|
16
16
|
rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
@@ -19,7 +19,7 @@ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
|
19
19
|
rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
|
20
20
|
rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
|
-
rxnn/training/rl.py,sha256=
|
22
|
+
rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
24
24
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
25
25
|
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.39.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.39.dist-info/METADATA,sha256=0Ky_SOITUSAzWBAcLtNl6Wq2n6ESnMNEs6_sBKezQ88,25960
|
38
|
+
rxnn-0.2.39.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.39.dist-info/RECORD,,
|
File without changes
|
File without changes
|