rxnn 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +4 -5
- rxnn/training/bml.py +7 -12
- rxnn/training/callbacks.py +12 -0
- rxnn/training/ddp.py +26 -0
- rxnn/training/mrl.py +374 -272
- rxnn/training/rl.py +49 -9
- {rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/METADATA +1 -1
- {rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/RECORD +10 -9
- {rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/LICENSE +0 -0
- {rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
 from .callbacks import TrainerCallback
+from .ddp import get_os_ddp_config, distributed_value_mean
 
 
 class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
         optimizer = self.optimizer
 
         if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
             self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
             train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
         if self.validation_dataset:
            self.validation_steps = 0
            val_loss, val_metrics = self.validate(batch_size)
-           val_loss_tensor = torch.tensor(val_loss).to(self.device)
            if self.use_ddp:
-
-
+               val_loss = distributed_value_mean(val_loss, device=self.device)
+
            self.validation_metrics[epoch] = val_metrics
 
            if self.writer:
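For reference, the validation-loss averaging above now goes through the distributed_value_mean helper from the new rxnn/training/ddp.py module shown later in this diff. A minimal sketch of the call pattern, assuming an initialized process group when running in DDP mode; the wrapper function is illustrative, not trainer code:

import torch
from rxnn.training.ddp import distributed_value_mean

def average_validation_loss(val_loss: float, device: torch.device, use_ddp: bool) -> float:
    # In DDP mode the helper converts the Python float to a tensor on the given
    # device, all-reduces it across ranks and divides by the world size.
    if use_ddp:
        return distributed_value_mean(val_loss, device=device)
    return val_loss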
rxnn/training/bml.py
CHANGED
@@ -7,6 +7,7 @@ import torch.distributed as dist
 from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
 from .models import MLMTrainingModel, JointTrainingModel
+from .ddp import distributed_mean
 
 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)
 
         metrics = {
             'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)
 
         metrics = {
             'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
         node_mlm_acc = mlm_acc.item()
         node_alm_acc = alm_acc.item()
         if self.use_ddp:
-
-
-
-
-            avg_dec_loss = avg_dec_loss / dist.get_world_size()
-            avg_enc_loss = avg_enc_loss / dist.get_world_size()
-            mlm_acc = mlm_acc / dist.get_world_size()
-            alm_acc = alm_acc / dist.get_world_size()
+            avg_dec_loss = distributed_mean(avg_dec_loss)
+            avg_enc_loss = distributed_mean(avg_enc_loss)
+            mlm_acc = distributed_mean(mlm_acc)
+            alm_acc = distributed_mean(alm_acc)
 
         metrics = {
             'accuracy': {
rxnn/training/callbacks.py
CHANGED
@@ -536,6 +536,9 @@ class MrlTrainerCallback:
     def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass
 
+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        pass
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         pass
 
@@ -543,6 +546,9 @@ class MrlTrainerCallback:
                           critic_loss: float) -> None:
         pass
 
+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        pass
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         pass
 
@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
 
+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
 
@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                           critic_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
 
+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
 
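These two hooks pair with the new internal update-epoch loop added in mrl.py further down in this diff: on_update_epoch_start fires before each internal update epoch and on_update_epoch_end after it, with the mean policy and critic losses. A minimal sketch of a custom callback built on them; the UpdateEpochHistory class and its history attribute are illustrative, not part of the package, and it assumes the base class needs no constructor arguments:

import torch.nn as nn
from rxnn.training.callbacks import MrlTrainerCallback

class UpdateEpochHistory(MrlTrainerCallback):
    """Example callback that records mean losses after every internal update epoch."""

    def __init__(self) -> None:
        super().__init__()
        # Illustrative storage, not a package field.
        self.history: list[tuple[int, int, float, float]] = []

    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int,
                            update_epoch: int, policy_loss: float, critic_loss: float) -> None:
        self.history.append((global_epoch, update_epoch, policy_loss, critic_loss))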
rxnn/training/ddp.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+import torch.distributed as dist
+import os
+from ..utils import set_random_seed
+
+def get_os_ddp_config():
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    return rank, world_size
+
+def distributed_mean(x: torch.Tensor) -> torch.Tensor:
+    """Average tensor across all devices"""
+    x = x.clone()
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    x /= dist.get_world_size()
+    return x
+
+def distributed_value_mean(value: float, device: torch.device = None) -> float:
+    """Average float value across all devices"""
+    tensor = torch.tensor(value, device=device)
+    reduced = distributed_mean(tensor)
+    return reduced.item()
+
+def set_distributed_random_seed(seed: int):
+    rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
+    set_random_seed(seed + rank)
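A short usage sketch for these helpers (not from the package docs), assuming the default process group has already been initialized - for example by one of the trainers above with use_ddp=True - and that RANK/WORLD_SIZE are provided by the launcher; the metric values are made up for illustration:

import torch
from rxnn.training.ddp import distributed_mean, distributed_value_mean, set_distributed_random_seed

# Requires dist.init_process_group(...) to have been called already.
set_distributed_random_seed(42)                      # each rank gets seed 42 + rank

local_acc = torch.tensor(0.87)                       # made-up per-rank metric tensor
global_acc = distributed_mean(local_acc)             # tensor averaged over all ranks
global_loss = distributed_value_mean(1.234, device=local_acc.device)  # float averaged over all ranks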
rxnn/training/mrl.py
CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union, TypeAlias, Literal
+from typing import Optional, TypedDict, Union, TypeAlias, Literal, Callable
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -13,7 +13,7 @@ from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
-
+from .ddp import get_os_ddp_config, distributed_mean
 
 class MrlConfig(TypedDict):
     lr: float
@@ -24,6 +24,11 @@ class MrlConfig(TypedDict):
     critic_max_len: int
     weight_decay: float
     critic_weight_decay: float
+    update_epochs: int
+    pad_token_id: int
+    end_token_id: int
+    callbacks: Optional[list[MrlTrainerCallback]]
+    memory_aware_critic: bool
 
 
 class MrlStrategy(Enum):
@@ -31,8 +36,11 @@
     MULTI_STEP_STRATEGY = 2
     LONG_RANGE_STRATEGY = 3
 
+
+UnfreezeStrategyFn = Callable[[int], None]
 UnfreezeItem = Union[int, tuple[int, float]]
-UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int], UnfreezeStrategyFn]
+
 
 class CurriculumConfig(TypedDict):
     steps: int
@@ -52,6 +60,7 @@ class CurriculumConfig(TypedDict):
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
+    update_epochs: Optional[int]
 
 
 class SamplerConfig(TypedDict):
@@ -66,6 +75,7 @@ class MrlTrajectoryStep(TypedDict):
     log_probs: torch.Tensor
     reward: list[float]
     reference: TokenizedDict
+    done: bool
 
 
 class MrlTrajectoryEpisode(TypedDict):
@@ -84,21 +94,25 @@ class MRLTrainer:
             rl_algorithm: RlAlgorithm,
             sampler_config: Optional[SamplerConfig] = None,
             log_dir: str = None,
-            pad_token_id: int = 0,
-            end_token_id: int = 3,
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            callbacks: list[MrlTrainerCallback] = None,
-
     ):
         """
-        Trainer for Memory Reinforcement Learning (MRL)
+        Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
 
        Args:
-            actor: MRL Actor model with encoder, decoder and memory attention.
-            critic: Critic network for advantage estimation.
-
+            actor (MrlActorModel): MRL Actor model with encoder, decoder and memory attention.
+            critic (MrlCriticModel): MRL Critic network for advantage estimation.
+            reward (MrlRewardModel): MRL Reward model or extension.
+            device (torch.device): Device used for training.
+            config (MrlConfig): Configuration dictionary with hyperparameters.
+            rl_algorithm (RlAlgorithm): Reinforcement Learning algorithm (currently only PPO available).
+            sampler_config (SamplerConfig): Sampler configuration.
+            log_dir (str): Log directory for TensorBoard logs.
+            use_ddp (bool): Use Distributed Data Parallel mode.
+            use_amp (bool): Use AMP Autocast for training.
+            dtype (torch.dtype): Data type used in training - in AMP mode it's auto cast, otherwise data and model are transformed to this type
        """
        self.actor = actor
        self.critic = critic
@@ -107,6 +121,10 @@ class MRLTrainer:
        self.device = device
        self.max_seq_len = config.get('max_seq_len', 256)
        self.critic_max_len = config.get('critic_max_len', 512)
+       self.memory_aware_critic = config.get('memory_aware_critic', False)
+       # Internal update epochs config
+       self.shared_update_epochs = config.get('update_epochs', 10)
+       self.update_epochs = self.shared_update_epochs
 
        # Move models to device
        if use_amp:
@@ -117,14 +135,15 @@ class MRLTrainer:
        self.critic.to(self.device, dtype=dtype)
 
        # Batch Sampler for answer generation
-       self.generator =
+       self.generator = None
        self.sampler_config = SamplerConfig(
            temperature=1.0,
            top_k=None,
            top_p=None,
        ) if sampler_config is None else sampler_config
 
-       self.pad_token_id = pad_token_id
+       self.pad_token_id = config.get('pad_token_id', 0)
+       self.end_token_id = config.get('end_token_id', 3)
 
        self.use_ddp = use_ddp
        self.use_amp = use_amp
@@ -172,7 +191,7 @@ class MRLTrainer:
        self.eval_dataset = None
        self.random_resets_ratio = 0.0
        self.strategy = None
-       self.shared_callbacks = callbacks
+       self.shared_callbacks = config.get('callbacks', [])
        self.callbacks = []
        self.global_epoch = 0
        self.global_epochs_count = 0
@@ -187,8 +206,8 @@ class MRLTrainer:
    ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
        if memory_lr is not None:
            optimizer = torch.optim.AdamW([
-               {
-               {
+               {'params': self.actor.not_memory_parameters(), 'lr': lr},
+               {'params': self.actor.memory_parameters(), 'lr': memory_lr},
            ],
                weight_decay=weight_decay,
            )
@@ -207,12 +226,10 @@ class MRLTrainer:
 
        return optimizer, critic_optimizer
 
-
    def _init_steps(self):
        return {
            'collect': 0,
-           '
-           'rl': 0,
+           'train': 0,
            'eval': 0,
        }
 
@@ -221,9 +238,12 @@ class MRLTrainer:
        self.epoch_step[step_type] += 1
        self.stage_step[step_type] += 1
 
-   def reset_stm(self) -> bool:
+   def reset_stm(self, force: bool = False) -> bool:
        """Reset Short-Term Memory state with random reset ratio."""
-       if
+       if force:
+           self.actor.reset_memory()
+           return True
+       elif self.random_resets_ratio == 1.0:
            self.actor.reset_memory()
            return True
        else:
@@ -351,7 +371,7 @@ class MRLTrainer:
        # state from existing one, instead of new random one)
        reset_done = self.reset_stm()
 
-       # 4. Reset reward prev data running mean - it's calculated for
+       # 4. Reset reward prev data running mean - it's calculated for multistep retention, we have to reset it before episode
        self.reward.reset_running_mean()
 
        # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
@@ -406,6 +426,7 @@ class MRLTrainer:
                'log_probs': log_probs.detach().cpu(),
                'reward': reward,
                'reference': interaction['answer'],
+               'done': is_last_interaction,
            }
            episode_steps.append(trajectory)
            episode_rewards.append(reward)
@@ -432,201 +453,268 @@ class MRLTrainer:
 
        return trajectories
 
-   def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
-       # 1. Calculate values with critic encoder
-       values = self.critic(
-           inputs['input_ids'],
-           attention_mask=inputs['attention_mask'],
-       ).squeeze()
-       # 2. Calculate critic loss
-       loss = self.rl_algorithm.critic_loss(values, rewards)
-       return loss
-
    def _critic_writer(self, critic_loss: float, epoch: int):
        if self.writer is not None:
-           self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['
+           self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['train'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
-                                  self.epoch_step['
+                                  self.epoch_step['train'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
-                                  self.stage_step['
-
-   def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
-                     rewards: list[torch.Tensor], epoch: int):
-       """Update critic network using MSE loss."""
-       # 1. Run critic updates for all collected batches
-       critic_losses = []
-       for step_idx, (state, reward) in enumerate(zip(states, rewards)):
-           self._increment_steps('critic')
-           # 2. Move state batches to training device (GPU)
-           prev_query, prev_answer, next_query = self._move_multiple_batches(*state)
-
-           # 3. Reset critic gradients
-           self.critic_optimizer.zero_grad()
-
-           # 4. Run critic and calculate loss - in autocast on/off mode
-           if self.use_amp:
-               # Move tensors to training device and calculate loss in autocast mode
-               batch_rewards = reward.to(self.device)
-               with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                   # Concatenate state into single critic input sequence
-                   inputs = smart_concat_critic_states(
-                       prev_query, prev_answer, next_query,
-                       max_length=self.critic_max_len,
-                       pad_token_id=self.pad_token_id,
-                   )
-                   loss = self._critic_loss(inputs, batch_rewards)
-               # Run backpropagation with scaler
-               self.critic_scaler.scale(loss).backward()
-               # Unscale and clip gradients
-               self.critic_scaler.unscale_(self.critic_optimizer)
-               torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
-               # Run scaled optimization step
-               self.critic_scaler.step(self.critic_optimizer)
-               self.critic_scaler.update()
-           else:
-               # Concatenate state into single critic input sequence
-               inputs = smart_concat_critic_states(
-                   prev_query, prev_answer, next_query,
-                   max_length=self.critic_max_len,
-                   pad_token_id=self.pad_token_id,
-               )
-               # Calculate loss
-               loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
-               # Run backpropagation
-               loss.backward()
-               # Clip gradients
-               torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
-               # Run optimizer step
-               self.critic_optimizer.step()
-           critic_loss = loss.item()
-           self._critic_writer(critic_loss, epoch)
-
-           # 5. Run "on critic updated" callbacks
-           for cb in self.callbacks:
-               cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)
-
-           # 6. Accumulate loss for epoch callbacks
-           critic_losses.append(critic_loss)
-
-       # 7. Calculate mean loss for epoch callbacks
-       critic_mean_loss = torch.tensor(critic_losses).mean().item()
-
-       return critic_mean_loss
-
-   def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
-       with torch.no_grad():
-           values = self.critic(critic_state['input_ids'],
-                                attention_mask=critic_state['attention_mask']).squeeze()
-           return self.rl_algorithm.calculate_advantages(rewards, values)
+                                  self.stage_step['train'])
 
    def _rl_writer(self, policy_loss: float, epoch: int):
        if self.writer is not None:
-           self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['
+           self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['train'])
            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
-                                  self.epoch_step['
-           self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
+                                  self.epoch_step['train'])
+           self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
+                                  self.stage_step['train'])
 
-   def
+   def update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
+                     epoch: int) -> float:
+       # 1. Reset critic gradients
+       self.critic_optimizer.zero_grad()
+
+       # 2. Update critic - with autocast on/off
+       if self.use_amp:
+           with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+               # 2.1 Concat states and calculate critic loss
+               critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+                                                         pad_token_id=self.pad_token_id)
+               values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+               critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+           # 2.2 Run backpropagation with scaler
+           self.critic_scaler.scale(critic_loss).backward()
+           # 2.3 Unscale and clip gradients
+           self.critic_scaler.unscale_(self.critic_optimizer)
+           torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+           # 2.4 Run scaled optimization step
+           self.critic_scaler.step(self.critic_optimizer)
+           self.critic_scaler.update()
+       else:
+           # 2.1 Concat states and calculate critic loss
+           critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+                                                     pad_token_id=self.pad_token_id)
+           values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+           critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+           # 2.2 Run backpropagation
+           critic_loss.backward()
+           # 2.3 Clip gradients
+           torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+           # 2.4 Run optimizer step
+           self.critic_optimizer.step()
+       # 3. Get float loss value for callbacks/writer
+       critic_loss_item = critic_loss.item()
+
+       # 4. Write to TensorBoard
+       self._critic_writer(critic_loss_item, epoch)
+
+       # 5. Run "on critic updated" callbacks
+       for cb in self.callbacks:
+           cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['train'], critic_loss_item)
+       # 6. Return loss item
+       return critic_loss_item
+
+   def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
+                    advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
+       # 1. Reset actor gradients
+       self.optimizer.zero_grad()
+       # 2. Unpack state dicts
+       query, answer, next_query = state
+
+       # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss (skip if it was updated before with memory aware critic)
+       if not self.memory_aware_critic:
+           self.encode_and_update_stm(query, answer)
+
+       # 4. Update actor - with autocast on/off
+       if self.use_amp:
+           with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+               # 4.1 Concatenate next query and action and get action logits from decoder
+               inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+                                     pad_token_id=self.pad_token_id)
+               logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+                                   action=MrlActorAction.DECODE)
+               # 4.2 Calculate policy loss with selected algorithm
+               policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
+                                                           advantages)
+           # 4.3 Run backpropagation with scaler
+           self.scaler.scale(policy_loss).backward(retain_graph=True)
+           # 4.4 Unscale and clip gradient norms
+           self.scaler.unscale_(self.optimizer)
+           torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+                                          error_if_nonfinite=False)
+           # 4.5 Run scaled optimization step
+           self.scaler.step(self.optimizer)
+           self.scaler.update()
+       else:
+           # 4.1 Concatenate next query and action and get action logits from decoder
+           inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+                                 pad_token_id=self.pad_token_id)
+           logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+                               action=MrlActorAction.DECODE)
+           # 4.2 Calculate policy loss with selected algorithm
+           policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
+           # 4.3 Run backpropagation
+           policy_loss.backward(retain_graph=True)
+           # 4.4 Clip gradient norms
+           torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+                                          error_if_nonfinite=False)
+           # 4.5 Run scaled optimization step
+           self.optimizer.step()
+       # 5. Get float loss value for callbacks/writer
+       policy_loss_item = policy_loss.item()
+
+       # 6. Write to TensorBoard
+       self._rl_writer(policy_loss_item, epoch)
+
+       # 7. Run "on batch updated" callback
+       for cb in self.callbacks:
+           cb.on_batch_updated(self.actor, epoch, self.epoch_step['train'], policy_loss_item)
+
+       # 8. Return loss item
+       return policy_loss_item
+
+   def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
+               epoch: int, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform PPO update step using trajectories."""
        # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
        # memory, based on collected episode data
        all_losses = []
-
+       critic_losses = []
        for episode_idx, episode in enumerate(trajectories):
            episode_steps = episode['steps']
            should_reset_stm = episode['reset_stm']
 
-           # 2.
+           # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
+           num_steps = len(episode_steps)
+           start = episode_idx * num_steps
+           end = start + num_steps
+           episode_critic_values = ref_values[start:end]
+           episode_advantages = advantages[start:end]
+
+           # 3. Reset memory for current batch episode
            if should_reset_stm:
-               self.reset_stm()
+               self.reset_stm(force=True)
 
-           #
-           for step in episode_steps:
-               self._increment_steps('
-
+           # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
+           for step_idx, step in enumerate(episode_steps):
+               self._increment_steps('train')
+               # 5. Get and move to device collected states, action and log probs
+               state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
               query, answer, next_query = self._move_multiple_batches(*state)
               action = self._move_batch(action)
               log_probs = log_probs.to(self.device)
-              rewards = torch.tensor(reward).to(self.device)
-
-              # 4. Compute advantages using critic
-              if self.use_amp:
-                  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                      critic_state = smart_concat_critic_states(query, answer, next_query,
-                                                                max_length=self.critic_max_len,
-                                                                pad_token_id=self.pad_token_id)
-                      advantages = self._critic_advantages(critic_state, rewards)
-              else:
-                  critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
-                                                            pad_token_id=self.pad_token_id)
-                  advantages = self._critic_advantages(critic_state, rewards)
-
-              # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
-              self.encode_and_update_stm(query, answer)
-              # 6. Concatenate next query and action and get action logits from decoder
-              if self.use_amp:
-                  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                      inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                            pad_token_id=self.pad_token_id)
-                      logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
-                                          action=MrlActorAction.DECODE)
-              else:
-                  inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                        pad_token_id=self.pad_token_id)
-                  logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
-                                      action=MrlActorAction.DECODE)
-
-              # 7. Calculate RL Algorithm (PPO etc.) loss
-              policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, log_probs, advantages)
-
-              # 8. Reset gradients
-              self.optimizer.zero_grad()
-
-              # 9. Update the model in AMP or regular mode
-              if self.use_amp:
-                  self.scaler.scale(policy_loss).backward(retain_graph=True)
-                  self.scaler.unscale_(self.optimizer)
-                  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                                 error_if_nonfinite=False)
-                  self.scaler.step(self.optimizer)
-                  self.scaler.update()
-              else:
-                  policy_loss.backward(retain_graph=True)
-                  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                                 error_if_nonfinite=False)
-                  self.optimizer.step()
 
-
-
-
+               # 6. Select advantages and reference values for current step (batch_size)
+               step_critic_values = episode_critic_values[step_idx]
+               step_advantages = episode_advantages[step_idx]
 
-              #
-
-
+               # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
+               if self.memory_aware_critic:
+                   self.encode_and_update_stm(query, answer)
+
+               # 7. Update critic
+               critic_loss_item = self.update_critic((query, answer, next_query), step_critic_values, epoch)
 
-
+               # 8. Accumulate critic loss for epoch callbacks
+               critic_losses.append(critic_loss_item)
 
-
-
-
-
-
+               # 9. Update actor
+               policy_loss_item = self.update_actor((query, answer, next_query), action, step_advantages, log_probs,
+                                                    epoch)
+               all_losses.append(policy_loss_item)
+       # 10. Return mean losses for epoch callbacks
+       return torch.mean(torch.tensor(all_losses)), torch.mean(torch.tensor(critic_losses))
+
+   def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
+       if self.memory_aware_critic:
+           flat_trajectories = [
+               (t, i == 0 and episode['reset_stm'])
+               for episode in trajectories
+               for i, t in enumerate(episode['steps'])
+           ]
+           values = torch.stack([
+               self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+           ]).to(self.device)
+           rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
+           dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
+       else:
+           flat_trajectories = [t for episode in trajectories for t in episode['steps']]
+           values = torch.stack([
+               self._critic_values(*self._move_multiple_batches(*t['state'])) for t in flat_trajectories
+           ]).to(self.device)
+           rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
+           dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
+       return values, rewards, dones
+
+   def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+       # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
+       with torch.no_grad():
+           # 2. Reset STM if it was reset in trajectory collection
+           if reset_stm:
+               self.reset_stm(force=True)
+           # 3. Encode and update STM for critic
+           self.encode_and_update_stm(*moved_state)
+           # 4. Get concatenated critic states
+           inputs = smart_concat_critic_states(
+               *moved_state,
+               max_length=self.critic_max_len,
+               pad_token_id=self.pad_token_id,
+           )
+           # 5. Calculate values for current batch
+           return self.critic(inputs['input_ids'],
+                              attention_mask=inputs['attention_mask']).squeeze()
+
+   def _critic_values(self, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+       # 1. Calculate critic values
+       with torch.no_grad():
+           # 2. Get concatenated critic states
+           inputs = smart_concat_critic_states(
+               *moved_state,
+               max_length=self.critic_max_len,
+               pad_token_id=self.pad_token_id,
+           )
+           # 3. Calculate values for current batch
+           return self.critic(inputs['input_ids'],
+                              attention_mask=inputs['attention_mask']).squeeze()
 
    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
        """Train for one epoch."""
        # 1. Collect trajectories for current epoch
        trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
-       # 2. Flatten trajectories and collect
-
-
-
+       # 2. Flatten trajectories, call critic and collect values, dones and rewards, and calculate advantages
+       if self.use_amp:
+           with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+               values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+               advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
+       else:
+           values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+           advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
+
+       # 3. Run internal update epochs
+       critic_loss_sum, policy_loss_sum = 0.0, 0.0
+       for update_epoch in range(self.update_epochs):
+           # 4. Run 'on update epoch start' callbacks
+           for cb in self.callbacks:
+               cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
+
+           # 5. Run RL algorithm step
+           policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)
+
+           if self.use_ddp:
+               policy_loss = distributed_mean(policy_loss)
+               critic_loss = distributed_mean(critic_loss)
+
+           # 6. Run 'on update epoch end' callbacks
+           for cb in self.callbacks:
+               cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)
 
-
-
+           # 7. Accumulate losses for epoch callbacks
+           critic_loss_sum += critic_loss
+           policy_loss_sum += policy_loss
 
-       #
-       return
+       # 8. Return policy and critic mean losses for epoch callbacks
+       return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs
 
    def _eval_loader(self, batch_size: int):
        if self.use_ddp:
@@ -731,12 +819,11 @@ class MRLTrainer:
                cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
        # 15. Calculate average reward
+       avg_reward = (total_reward / count) if count > 0 else torch.tensor(0.0).to(self.device)
        if self.use_ddp:
-
-
-
-       else:
-           avg_reward = (total_reward / count).item() if count > 0 else 0
+           avg_reward = distributed_mean(avg_reward)
+
+       avg_reward = avg_reward.item()
 
        should_stop_stage = False
        # 16. Run "on eval end" callbacks
@@ -747,55 +834,6 @@ class MRLTrainer:
 
        return should_stop_stage
 
-   def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
-       # 1. Set common fields based on config
-       self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
-       self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
-       self.eval_dataset = config.get('eval_dataset', None)  # evaluation dataset for current curriculum stage
-       self.callbacks = config.get('callbacks',
-                                   self.shared_callbacks)  # trainer callbacks for current curriculum stage
-       self.strategy = config.get('strategy',
-                                  MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
-       self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-       if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
-           if config.get('separate_memory_lr', False):
-               self.optim_config = {
-                   'lr': config.get('lr', self.base_optim_config['lr']),
-                   'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
-                   'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
-                   'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
-                   'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
-               }
-           else:
-               self.optim_config = {
-                   'lr': config.get('lr', self.base_optim_config['lr']),
-                   'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
-                   'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
-                   'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
-               }
-           self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
-       elif self.optim_config != self.base_optim_config:
-           self.optim_config = self.base_optim_config
-           self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
-
-
-
-
-       # 2. Get epochs and random resets configs
-       epochs = config.get('epochs', 5)  # number of epochs for current stage
-       unfreeze_epoch = config.get('unfreeze_epoch',
-                                   0)  # epoch when components (other than memory) are unfrozen (before epoch starts)
-       random_resets = config.get('random_resets',
-                                  False)  # flag for using random STM resets (recommended, as model should learn transitions between different states)
-       random_resets_from = config.get('random_resets_from', None)  # epoch from which random STM resets are started
-       random_resets_ratio = config.get('random_resets_ratio',
-                                        None)  # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
-
-       # 3. Reset stage step counter
-       self.stage_step = self._init_steps()
-
-       return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
-
    def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
        is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
        if is_staged_unfreeze:
@@ -808,28 +846,31 @@ class MRLTrainer:
                    self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
                    print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
                elif epoch == update_epoch:
-
-
+                   self.actor.freeze_components('update')
+                   print(
+                       f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
 
            if isinstance(fetch_epoch, tuple):
                switch_epoch, mem_att_lr = fetch_epoch
-               if epoch ==
+               if epoch == switch_epoch:
                    self.actor.freeze_components('joint')
                    self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
                    print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
                elif epoch == fetch_epoch:
                    self.actor.freeze_components('fetch')
-                   print(
+                   print(
+                       f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
 
            if isinstance(joint_epoch, tuple):
                switch_epoch, model_lr = joint_epoch
-               if epoch ==
+               if epoch == switch_epoch:
                    self.actor.unfreeze_components()
                    self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
                    print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
                elif epoch == joint_epoch:
-
-
+                   self.actor.freeze_components('joint')
+                   print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+
            if epoch == all_epoch:
                self.actor.unfreeze_components()
                self.optimizer = self._init_unfreeze_optimizer('all', 0.)
@@ -871,6 +912,56 @@ class MRLTrainer:
 
        return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
 
+   def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[
+       tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
+       # 1. Set common fields based on config
+       self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
+       self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
+       self.eval_dataset = config.get('eval_dataset', None)  # evaluation dataset for current curriculum stage
+       self.callbacks = config.get('callbacks',
+                                   self.shared_callbacks)  # trainer callbacks for current curriculum stage
+       self.strategy = config.get('strategy',
+                                  MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
+       self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
+       self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
+       if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
+           'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+           if config.get('separate_memory_lr', False):
+               self.optim_config = {
+                   'lr': config.get('lr', self.base_optim_config['lr']),
+                   'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                   'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                   'critic_weight_decay': config.get('critic_weight_decay',
+                                                     self.base_optim_config['critic_weight_decay']),
+                   'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+               }
+           else:
+               self.optim_config = {
+                   'lr': config.get('lr', self.base_optim_config['lr']),
+                   'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                   'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                   'critic_weight_decay': config.get('critic_weight_decay',
+                                                     self.base_optim_config['critic_weight_decay']),
+               }
+           self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+       elif self.optim_config != self.base_optim_config:
+           self.optim_config = self.base_optim_config
+           self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+       # 2. Get epochs and random resets configs
+       epochs = config.get('epochs', 5)  # number of epochs for current stage
+       unfreeze_epoch = config.get('unfreeze_epoch',
+                                   0)  # epoch when components (other than memory) are unfrozen (before epoch starts)
+       random_resets = config.get('random_resets',
+                                  False)  # flag for using random STM resets (recommended, as model should learn transitions between different states)
+       random_resets_from = config.get('random_resets_from', None)  # epoch from which random STM resets are started
+       random_resets_ratio = config.get('random_resets_ratio',
+                                        None)  # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
+
+       # 3. Reset stage step counter
+       self.stage_step = self._init_steps()
+
+       return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
    def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
        """Start Memory Reinforcement Learning Curriculum."""
@@ -881,29 +972,36 @@ class MRLTrainer:
 
        # 1. Init DDP for distributed training mode
        if self.use_ddp:
-           rank = int(os.environ['RANK'])
-           world_size = int(os.environ['WORLD_SIZE'])
+           rank, world_size = get_os_ddp_config()
           dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
           self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
           self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
 
-       # 2.
+       # 2. Init BatchSampler with actor model (we have to run it after DDP init)
+       self.generator = BatchSampler(self.actor, self.device, end_token_id=self.end_token_id)
+
+       # 3. Run each curriculum step based on config
        for current_curriculum_step in curriculum_config:
-           #
+           # 4. Setup training config for curriculum step
           epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
           epochs, unfreeze_epoch = epochs_config
           random_resets, random_resets_from, random_resets_ratio = random_resets_config
           assert self.train_dataset is not None
 
-           #
+           # 5. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
           if unfreeze_epoch != 0:
-
-
-               print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+               if callable(unfreeze_epoch):
+                   unfreeze_epoch(-1)
               else:
-
-
-
+                   self.actor.freeze_components('joint')
+                   if isinstance(unfreeze_epoch, tuple):
+                       print(
+                           f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                   else:
+                       print(
+                           f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
+
+           # 6. Setup train DataLoader
           if self.use_ddp:
               train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
               dataloader = DataLoader(
@@ -912,6 +1010,7 @@ class MRLTrainer:
                   sampler=train_sampler,
                   pin_memory=True,
                   collate_fn=MrlCurriculumDataset.collate_mrl_batch,
+                  drop_last=True,
               )
           else:
               train_sampler = None
@@ -923,65 +1022,68 @@ class MRLTrainer:
                   collate_fn=MrlCurriculumDataset.collate_mrl_batch,
               )
 
-           #
+           # 7. Run selected number of epochs for given curriculum stage
           for epoch in range(epochs):
-               #
+               # 8. Increment global epoch
              self.global_epoch += 1
-               #
+               # 9. Run "on epoch start" callbacks (log info, etc.)
              for cb in self.callbacks:
                  cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
                                    self.global_epochs_count)
 
-               #
+               # 10. Reset steps counter for epoch
              self.epoch_step = self._init_steps()
 
-               #
+               # 11. Set random STM resets ratio from selected epoch
              if random_resets and random_resets_from <= epoch:
                  self.random_resets_ratio = random_resets_ratio
              else:
                  self.random_resets_ratio = 1.0
 
-               #
-
+               # 12. Apply the unfreeze strategy
+               if callable(unfreeze_epoch):
+                   unfreeze_epoch(epoch)
+               else:
+                   self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
 
-               #
+               # 13. Set epoch for distributed sampler
              if train_sampler is not None:
                  train_sampler.set_epoch(epoch)
 
-               #
+               # 14. Run reinforcement learning algorithms for current epoch
              policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)
 
-               #
+               # 15. If evaluation dataset is provided, run evaluation steps
              if self.eval_dataset:
                  should_stop_stage = self.evaluate(batch_size, epoch)
              else:
                  should_stop_stage = False
 
-               #
+               # 16. Finally, run "on epoch end" callbacks (save models, etc.)
              for cb in self.callbacks:
                  cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
                                  self.global_epochs_count)
 
-               #
+               # 17. Synchronize TensorBoard writer
              if self.writer:
                  self.writer.flush()
 
-               #
+               # 18. Synchronize devices in DDP mode
              if self.use_ddp:
                  dist.barrier()
 
-               #
+               # 19. Finish curriculum stage if rewards are not increased or reached threshold point
              if should_stop_stage:
                  break
 
-           #
+           # 20. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
           for cb in self.callbacks:
              cb.on_training_end(self.actor, self.critic, current_curriculum_step)
 
-       #
+       # 21. Training end - finish processes after all curriculum stages
       if self.use_ddp:
          dist.destroy_process_group()
 
-       #
+       # 22. Close writer
       if self.writer:
          self.writer.close()
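Taken together, the MRLTrainer changes move pad_token_id, end_token_id and the shared callbacks from constructor arguments into MrlConfig, and add update_epochs and memory_aware_critic for the new inner PPO update loop. A hedged, partial sketch of the config dictionary only; values are illustrative and other keys (learning rates, sequence lengths for your setup, etc.) are omitted:

# Partial, illustrative MrlConfig dict - not a complete configuration.
config = {
    'max_seq_len': 256,
    'critic_max_len': 512,
    'update_epochs': 10,           # new: internal PPO update epochs per training epoch (default 10)
    'pad_token_id': 0,             # new: moved here from the MRLTrainer constructor
    'end_token_id': 3,             # new: moved here from the MRLTrainer constructor
    'callbacks': [],               # new: shared trainer callbacks, also moved out of the constructor
    'memory_aware_critic': False,  # new: when True, STM is updated before critic values/updates
}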
rxnn/training/rl.py
CHANGED
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
-from typing import TypedDict
+from typing import TypedDict, Optional
 from .utils import TokenizedDict
+from .ddp import distributed_mean
 
 
 class RlAlgorithm(ABC):
@@ -17,21 +18,34 @@ class RlAlgorithm(ABC):
        pass
 
    @abstractmethod
-   def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+   def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        pass
 
    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        return self.critic_loss(rewards, values)
 
+
 class PPOConfig(TypedDict):
-   clip_eps: float
+   clip_eps: Optional[float]
+   gae_lambda: Optional[float]
+   gae_gamma: Optional[float]
+   entropy_coef: Optional[float]
+   use_distributed_advantage_norm: Optional[bool]
+
 
 class PPOAlgorithm(RlAlgorithm):
-   def __init__(self, config: PPOConfig):
+   def __init__(self, config: Optional[PPOConfig] = None):
       super(PPOAlgorithm, self).__init__()
 
+      if config is None:
+          config = {}
+
      # PPO Config
      self.clip_eps = config.get('clip_eps', 0.2)
+     self.gae_lambda = config.get('gae_lambda', 0.95)
+     self.gae_gamma = config.get('gae_gamma', 0.99)
+     self.entropy_coef = config.get('entropy_coef', 0.01)
+     self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
 
    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -78,11 +92,37 @@ class PPOAlgorithm(RlAlgorithm):
 
       # d) Entropy bonus
       entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
-      policy_loss -=
+      policy_loss -= self.entropy_coef * entropy
 
       return policy_loss
 
-   def
-
-
-
+   def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+       T, B = rewards.shape
+       advantages = torch.zeros_like(rewards, device=values.device)
+       last_advantage = 0
+       last_value = next_value.detach()
+
+       for t in reversed(range(T)):
+           if t == T - 1:
+               next_values = last_value
+           else:
+               next_values = values[t + 1]
+
+           # Mask next values if episode ended
+           next_values = next_values * ~dones[t]
+           delta = rewards[t] + self.gae_gamma * next_values - values[t]
+           advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+           last_advantage = advantages[t]
+
+       returns = advantages + values
+       return advantages, returns
+
+   def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+       advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
+       if self.use_distributed_advantage_norm:
+           mean_advantage = distributed_mean(advantages.mean())
+           std_advantage = distributed_mean(advantages.std())
+           normalized_advantages = (advantages - mean_advantage) / (std_advantage + 1e-8)
+       else:
+           normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+       return normalized_advantages, ref_values
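calculate_advantages now computes GAE internally from per-step rewards, critic values and done flags shaped [steps, batch], where the final step of values acts as the bootstrap value, and it returns normalized advantages together with reference returns for the critic loss. A minimal shape-level sketch with made-up tensors (distributed normalization is off, so no process group is needed):

import torch
from rxnn.training.rl import PPOAlgorithm

algorithm = PPOAlgorithm({'gae_lambda': 0.95, 'gae_gamma': 0.99, 'entropy_coef': 0.01})

steps, batch = 4, 8                           # illustrative sizes
rewards = torch.randn(steps, batch)           # one reward vector per step
values = torch.randn(steps, batch)            # critic values; the last step is the bootstrap
dones = torch.zeros(steps, batch, dtype=torch.bool)
dones[-1] = True                              # episodes end on the final step

advantages, ref_values = algorithm.calculate_advantages(rewards, values, dones)
print(advantages.shape, ref_values.shape)     # both are [steps - 1, batch]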
{rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/RECORD
CHANGED
@@ -11,14 +11,15 @@ rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=
-rxnn/training/bml.py,sha256=
-rxnn/training/callbacks.py,sha256
+rxnn/training/base.py,sha256=TGz_37RfI1qLI31GNRV5rLowW1kAHnJwqPm7DNfLfe4,11730
+rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
+rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
+rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=Aimiiqf_4p6dp5Ty9pY9VwetySBS_OFpCQlcVHVkO4Q,55124
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=eL3C0yryiNBgl_xb-D-5dyYUtK4V4-K4t3a60x5ir28,5142
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.31.dist-info/METADATA,sha256=zxD2qPHL_QrFH1bYZrMv4odbXE4B_YIVEpGDzV2MYEI,25960
+rxnn-0.2.31.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.31.dist-info/RECORD,,
{rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/LICENSE
File without changes
{rxnn-0.2.29.dist-info → rxnn-0.2.31.dist-info}/WHEEL
File without changes