rxnn 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- rxnn/training/callbacks.py +12 -0
- rxnn/training/mrl.py +255 -210
- rxnn/training/rl.py +32 -5
- {rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/METADATA +1 -1
- {rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/RECORD +7 -7
- {rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/LICENSE +0 -0
- {rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/WHEEL +0 -0
rxnn/training/callbacks.py
CHANGED
@@ -536,6 +536,9 @@ class MrlTrainerCallback:
     def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass
 
+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        pass
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         pass
 
@@ -543,6 +546,9 @@ class MrlTrainerCallback:
                           critic_loss: float) -> None:
         pass
 
+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        pass
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         pass
 
@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
 
+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
 
@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                           critic_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
 
+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
 
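The two new hooks bracket each internal update epoch that the trainer now runs over the collected trajectories. A minimal sketch of a custom callback using them, following the same subclassing pattern as MrlPrintCallback above; the class name and the loss-history attribute are invented for illustration and are not part of the library:

import torch.nn as nn
from rxnn.training.callbacks import MrlTrainerCallback

class UpdateEpochHistoryCallback(MrlTrainerCallback):
    """Illustrative callback that records the mean losses reported per update epoch."""

    def __init__(self) -> None:
        self.history: list[tuple[int, int, float, float]] = []

    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
        print(f'Epoch {global_epoch} | starting update epoch {update_epoch}')

    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int,
                            policy_loss: float, critic_loss: float) -> None:
        # policy_loss / critic_loss are the mean losses over the whole update epoch.
        self.history.append((global_epoch, update_epoch, policy_loss, critic_loss))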
rxnn/training/mrl.py
CHANGED
@@ -24,6 +24,7 @@ class MrlConfig(TypedDict):
     critic_max_len: int
     weight_decay: float
     critic_weight_decay: float
+    update_epochs: int
 
 
 class MrlStrategy(Enum):
@@ -31,9 +32,11 @@ class MrlStrategy(Enum):
     MULTI_STEP_STRATEGY = 2
     LONG_RANGE_STRATEGY = 3
 
+
 UnfreezeItem = Union[int, tuple[int, float]]
 UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
 
+
 class CurriculumConfig(TypedDict):
     steps: int
     epochs: int
@@ -52,6 +55,7 @@ class CurriculumConfig(TypedDict):
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
+    update_epochs: Optional[int]
 
 
 class SamplerConfig(TypedDict):
@@ -66,6 +70,7 @@ class MrlTrajectoryStep(TypedDict):
     log_probs: torch.Tensor
     reward: list[float]
     reference: TokenizedDict
+    done: bool
 
 
 class MrlTrajectoryEpisode(TypedDict):
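For reference, a sketch of where the new update_epochs field sits in the shared trainer config. The values are illustrative, only keys visible in this diff are shown, and 10 matches the default read in MRLTrainer.__init__ below:

# Illustrative values; remaining MrlConfig keys are omitted.
mrl_config = {
    'max_seq_len': 256,
    'critic_max_len': 512,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
    'update_epochs': 10,  # internal update epochs per outer training epoch
}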
@@ -107,6 +112,9 @@ class MRLTrainer:
         self.device = device
         self.max_seq_len = config.get('max_seq_len', 256)
         self.critic_max_len = config.get('critic_max_len', 512)
+        # Internal update epochs config
+        self.shared_update_epochs = config.get('update_epochs', 10)
+        self.update_epochs = self.shared_update_epochs
 
         # Move models to device
         if use_amp:
@@ -187,8 +195,8 @@ class MRLTrainer:
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
-                {
-                {
+                {'params': self.actor.not_memory_parameters(), 'lr': lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
                 weight_decay=weight_decay,
             )
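The corrected optimizer setup above passes two parameter groups with separate learning rates. A standalone sketch of the same AdamW pattern with a toy module; the submodule names and learning rates are made up here, while the trainer itself uses actor.not_memory_parameters() and actor.memory_parameters():

import torch
import torch.nn as nn

# Toy stand-in for the actor; 'backbone' and 'memory' are illustrative names.
model = nn.ModuleDict({'backbone': nn.Linear(16, 16), 'memory': nn.Linear(16, 16)})

optimizer = torch.optim.AdamW(
    [
        {'params': model['backbone'].parameters(), 'lr': 1e-4},  # non-memory parameters
        {'params': model['memory'].parameters(), 'lr': 5e-4},    # memory parameters, separate LR
    ],
    weight_decay=0.01,
)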
@@ -207,11 +215,9 @@ class MRLTrainer:
 
         return optimizer, critic_optimizer
 
-
     def _init_steps(self):
         return {
             'collect': 0,
-            'critic': 0,
             'rl': 0,
             'eval': 0,
         }
@@ -351,7 +357,7 @@ class MRLTrainer:
             # state from existing one, instead of new random one)
             reset_done = self.reset_stm()
 
-            # 4. Reset reward prev data running mean - it's calculated for
+            # 4. Reset reward prev data running mean - it's calculated for multistep retention, we have to reset it before episode
             self.reward.reset_running_mean()
 
             # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
@@ -406,6 +412,7 @@ class MRLTrainer:
                     'log_probs': log_probs.detach().cpu(),
                     'reward': reward,
                     'reference': interaction['answer'],
+                    'done': is_last_interaction,
                 }
                 episode_steps.append(trajectory)
                 episode_rewards.append(reward)
@@ -432,92 +439,23 @@ class MRLTrainer:
 
         return trajectories
 
-    def _critic_loss(self, inputs: TokenizedDict,
+    def _critic_loss(self, inputs: TokenizedDict, ref_values: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
             inputs['input_ids'],
             attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
-        loss = self.rl_algorithm.critic_loss(values,
+        loss = self.rl_algorithm.critic_loss(values, ref_values)
         return loss
 
     def _critic_writer(self, critic_loss: float, epoch: int):
         if self.writer is not None:
-            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['
+            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['rl'])
             self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
-                                   self.epoch_step['
+                                   self.epoch_step['rl'])
             self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
-                                   self.stage_step['
-
-    def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
-                      rewards: list[torch.Tensor], epoch: int):
-        """Update critic network using MSE loss."""
-        # 1. Run critic updates for all collected batches
-        critic_losses = []
-        for step_idx, (state, reward) in enumerate(zip(states, rewards)):
-            self._increment_steps('critic')
-            # 2. Move state batches to training device (GPU)
-            prev_query, prev_answer, next_query = self._move_multiple_batches(*state)
-
-            # 3. Reset critic gradients
-            self.critic_optimizer.zero_grad()
-
-            # 4. Run critic and calculate loss - in autocast on/off mode
-            if self.use_amp:
-                # Move tensors to training device and calculate loss in autocast mode
-                batch_rewards = reward.to(self.device)
-                with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                    # Concatenate state into single critic input sequence
-                    inputs = smart_concat_critic_states(
-                        prev_query, prev_answer, next_query,
-                        max_length=self.critic_max_len,
-                        pad_token_id=self.pad_token_id,
-                    )
-                    loss = self._critic_loss(inputs, batch_rewards)
-                # Run backpropagation with scaler
-                self.critic_scaler.scale(loss).backward()
-                # Unscale and clip gradients
-                self.critic_scaler.unscale_(self.critic_optimizer)
-                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
-                # Run scaled optimization step
-                self.critic_scaler.step(self.critic_optimizer)
-                self.critic_scaler.update()
-            else:
-                # Concatenate state into single critic input sequence
-                inputs = smart_concat_critic_states(
-                    prev_query, prev_answer, next_query,
-                    max_length=self.critic_max_len,
-                    pad_token_id=self.pad_token_id,
-                )
-                # Calculate loss
-                loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
-                # Run backpropagation
-                loss.backward()
-                # Clip gradients
-                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
-                # Run optimizer step
-                self.critic_optimizer.step()
-            critic_loss = loss.item()
-            self._critic_writer(critic_loss, epoch)
-
-            # 5. Run "on critic updated" callbacks
-            for cb in self.callbacks:
-                cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)
-
-            # 6. Accumulate loss for epoch callbacks
-            critic_losses.append(critic_loss)
-
-        # 7. Calculate mean loss for epoch callbacks
-        critic_mean_loss = torch.tensor(critic_losses).mean().item()
-
-        return critic_mean_loss
-
-    def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
-        with torch.no_grad():
-            values = self.critic(critic_state['input_ids'],
-                                 attention_mask=critic_state['attention_mask']).squeeze()
-            return self.rl_algorithm.calculate_advantages(rewards, values)
+                                   self.stage_step['rl'])
 
     def _rl_writer(self, policy_loss: float, epoch: int):
         if self.writer is not None:
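With this change the critic is no longer fit to raw rewards: _critic_loss regresses predicted values toward the reference values (the GAE returns) that calculate_advantages produces once per outer epoch. A minimal standalone sketch of that regression target, assuming an MSE objective (the body of the library's critic_loss is not shown in this diff):

import torch
import torch.nn.functional as F

def critic_mse_loss(values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
    # values: critic predictions V(s); ref_values: advantages + values, i.e. the GAE returns
    return F.mse_loss(values, ref_values.detach())

values = torch.randn(8)      # predicted values for a batch of 8 states
ref_values = torch.randn(8)  # reference values returned by calculate_advantages
loss = critic_mse_loss(values, ref_values)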
@@ -526,107 +464,208 @@ class MRLTrainer:
                                    self.epoch_step['rl'])
             self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])
 
-    def
+    def _update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
+                       epoch: int) -> float:
+        # 1. Reset critic gradients
+        self.critic_optimizer.zero_grad()
+
+        # 2. Update critic - with autocast on/off
+        if self.use_amp:
+            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                # 2.1 Concat states and calculate critic loss
+                critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+                                                          pad_token_id=self.pad_token_id)
+                critic_loss = self._critic_loss(critic_state, ref_values)
+            # 2.2 Run backpropagation with scaler
+            self.critic_scaler.scale(critic_loss).backward()
+            # 2.3 Unscale and clip gradients
+            self.critic_scaler.unscale_(self.critic_optimizer)
+            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+            # 2.4 Run scaled optimization step
+            self.critic_scaler.step(self.critic_optimizer)
+            self.critic_scaler.update()
+        else:
+            # 2.1 Concat states and calculate critic loss
+            critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+                                                      pad_token_id=self.pad_token_id)
+            critic_loss = self._critic_loss(critic_state, ref_values)
+            # 2.2 Run backpropagation
+            critic_loss.backward()
+            # 2.3 Clip gradients
+            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+            # 2.4 Run optimizer step
+            self.critic_optimizer.step()
+        # 3. Get float loss value for callbacks/writer
+        critic_loss_item = critic_loss.item()
+
+        # 4. Write to TensorBoard
+        self._critic_writer(critic_loss_item, epoch)
+
+        # 5. Run "on critic updated" callbacks
+        for cb in self.callbacks:
+            cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['rl'], critic_loss_item)
+        # 6. Return loss item
+        return critic_loss_item
+
+    def _update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
+                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
+        # 1. Reset actor gradients
+        self.optimizer.zero_grad()
+        # 2. Unpack state dicts
+        query, answer, next_query = state
+
+        # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss
+        self.encode_and_update_stm(query, answer)
+
+        # 4. Update actor - with autocast on/off
+        if self.use_amp:
+            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                # 4.1 Concatenate next query and action and get action logits from decoder
+                inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+                                      pad_token_id=self.pad_token_id)
+                logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+                                    action=MrlActorAction.DECODE)
+                # 4.2 Calculate policy loss with selected algorithm
+                policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
+                                                            advantages)
+            # 4.3 Run backpropagation with scaler
+            self.scaler.scale(policy_loss).backward(retain_graph=True)
+            # 4.4 Unscale and clip gradient norms
+            self.scaler.unscale_(self.optimizer)
+            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+                                           error_if_nonfinite=False)
+            # 4.5 Run scaled optimization step
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+        else:
+            # 4.1 Concatenate next query and action and get action logits from decoder
+            inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+                                  pad_token_id=self.pad_token_id)
+            logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+                                action=MrlActorAction.DECODE)
+            # 4.2 Calculate policy loss with selected algorithm
+            policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
+            # 4.3 Run backpropagation
+            policy_loss.backward(retain_graph=True)
+            # 4.4 Clip gradient norms
+            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+                                           error_if_nonfinite=False)
+            # 4.5 Run scaled optimization step
+            self.optimizer.step()
+        # 5. Get float loss value for callbacks/writer
+        policy_loss_item = policy_loss.item()
+
+        # 6. Write to TensorBoard
+        self._rl_writer(policy_loss_item, epoch)
+
+        # 7. Run "on batch updated" callback
+        for cb in self.callbacks:
+            cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
+
+        # 8. Return loss item
+        return policy_loss_item
+
+    def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
+                epoch: int, batch_size: int) -> tuple[float, float]:
         """Perform PPO update step using trajectories."""
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
-
+        critic_losses = []
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
 
-            # 2.
+            # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
+            start = episode_idx * episode_steps
+            end = start + episode_steps
+            episode_critic_values = ref_values[start:end]
+            episode_advantages = advantages[start:end]
+
+            # 3. Reset memory for current batch episode
             if should_reset_stm:
                 self.reset_stm()
 
-            #
-            for step in episode_steps:
+            # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
+            for step_idx, step in enumerate(episode_steps):
                 self._increment_steps('rl')
-
+                # 5. Get and move to device collected states, action and log probs
+                state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
                 query, answer, next_query = self._move_multiple_batches(*state)
                 action = self._move_batch(action)
                 log_probs = log_probs.to(self.device)
-                rewards = torch.tensor(reward).to(self.device)
-
-                # 4. Compute advantages using critic
-                if self.use_amp:
-                    with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                        critic_state = smart_concat_critic_states(query, answer, next_query,
-                                                                  max_length=self.critic_max_len,
-                                                                  pad_token_id=self.pad_token_id)
-                        advantages = self._critic_advantages(critic_state, rewards)
-                else:
-                    critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
-                                                              pad_token_id=self.pad_token_id)
-                    advantages = self._critic_advantages(critic_state, rewards)
-
-                # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
-                self.encode_and_update_stm(query, answer)
-                # 6. Concatenate next query and action and get action logits from decoder
-                if self.use_amp:
-                    with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                        inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                              pad_token_id=self.pad_token_id)
-                        logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
-                                            action=MrlActorAction.DECODE)
-                else:
-                    inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                          pad_token_id=self.pad_token_id)
-                    logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
-                                        action=MrlActorAction.DECODE)
-
-                # 7. Calculate RL Algorithm (PPO etc.) loss
-                policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, log_probs, advantages)
-
-                # 8. Reset gradients
-                self.optimizer.zero_grad()
-
-                # 9. Update the model in AMP or regular mode
-                if self.use_amp:
-                    self.scaler.scale(policy_loss).backward(retain_graph=True)
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                                   error_if_nonfinite=False)
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                else:
-                    policy_loss.backward(retain_graph=True)
-                    torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                                   error_if_nonfinite=False)
-                    self.optimizer.step()
 
-
-
-
+                # 6. Select advantages and reference values for current step (batch_size)
+                step_critic_values = episode_critic_values[step_idx]
+                step_advantages = episode_advantages[step_idx]
 
-                #
-
-                    cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
+                # 7. Update critic
+                critic_loss_item = self._update_critic((query, answer, next_query), step_critic_values, epoch)
 
-
+                # 8. Accumulate critic loss for epoch callbacks
+                critic_losses.append(critic_loss_item)
 
-
+                # 9. Update actor
+                policy_loss_item = self._update_actor((query, answer, next_query), action, step_advantages, log_probs,
+                                                      epoch)
+                all_losses.append(policy_loss_item)
+        # 10. Return mean losses for epoch callbacks
+        return torch.mean(torch.tensor(all_losses)).item(), torch.mean(torch.tensor(critic_losses)).item()
+
+    def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
         flat_trajectories = [t for episode in trajectories for t in episode['steps']]
-
-
-
+        values = [
+            self._critic_values(
+                smart_concat_critic_states(
+                    *self._move_multiple_batches(*t['state']),
+                    max_length=self.critic_max_len,
+                    pad_token_id=self.pad_token_id,
+                )
+            ) for t in flat_trajectories
+        ]
+        values = torch.stack(values).to(self.device)
+        rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
+        dones = torch.stack([torch.BoolTensor(t['done']) for t in flat_trajectories]).to(self.device)
+        return values, rewards, dones
+
+    def _critic_values(self, inputs: TokenizedDict) -> torch.Tensor:
+        with torch.no_grad():
+            return self.critic(inputs['input_ids'],
+                               attention_mask=inputs['attention_mask']).squeeze()
+
+        # return self.rl_algorithm.calculate_advantages(rewards, values)
 
     def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
         # 1. Collect trajectories for current epoch
         trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
-        # 2. Flatten trajectories and collect
-
-
-
+        # 2. Flatten trajectories, call critic and collect values, dones and rewards, and calculate advantages
+        if self.use_amp:
+            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+                advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
+        else:
+            values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+            advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
 
-        #
-
+        # 3. Run internal update epochs
+        critic_loss_sum, policy_loss_sum = 0.0, 0.0
+        for update_epoch in range(self.update_epochs):
+            # 4. Run 'on update epoch start' callbacks
+            for cb in self.callbacks:
+                cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
+            # 5. Run RL algorithm step
+            policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)
 
-
-
+            for cb in self.callbacks:
+                cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)
+
+            critic_loss_sum += critic_loss
+            policy_loss_sum += policy_loss
+
+        # 6. Return policy and critic mean losses for epoch callbacks
+        return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs
 
     def _eval_loader(self, batch_size: int):
         if self.use_ddp:
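The restructured train_epoch follows the usual PPO recipe: collect trajectories once, compute advantages and reference values once, then reuse the same data for several optimisation passes bracketed by the new callbacks. A standalone sketch of that control flow; collect, compute_advantages and update are stand-ins invented for illustration:

import torch

def collect() -> list[dict]:
    # Stand-in for trajectory collection.
    return [{'obs': torch.randn(4), 'reward': torch.rand(1)} for _ in range(8)]

def compute_advantages(batch: list[dict]) -> torch.Tensor:
    # Stand-in for the critic + GAE pass, done once per outer epoch.
    return torch.randn(len(batch))

def update(batch: list[dict], advantages: torch.Tensor) -> float:
    # Stand-in for one actor/critic pass over the whole batch.
    return float(advantages.abs().mean())

update_epochs = 10  # matches the trainer's default 'update_epochs'
batch = collect()
advantages = compute_advantages(batch)  # fixed for all inner update epochs
losses = [update(batch, advantages) for _ in range(update_epochs)]
mean_loss = sum(losses) / update_epochs  # what the epoch-level callbacks receive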
@@ -747,55 +786,6 @@ class MRLTrainer:
 
         return should_stop_stage
 
-    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
-        # 1. Set common fields based on config
-        self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
-        self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
-        self.eval_dataset = config.get('eval_dataset', None)  # evaluation dataset for current curriculum stage
-        self.callbacks = config.get('callbacks',
-                                    self.shared_callbacks)  # trainer callbacks for current curriculum stage
-        self.strategy = config.get('strategy',
-                                   MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
-        self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
-            if config.get('separate_memory_lr', False):
-                self.optim_config = {
-                    'lr': config.get('lr', self.base_optim_config['lr']),
-                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
-                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
-                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
-                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
-                }
-            else:
-                self.optim_config = {
-                    'lr': config.get('lr', self.base_optim_config['lr']),
-                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
-                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
-                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
-                }
-            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
-        elif self.optim_config != self.base_optim_config:
-            self.optim_config = self.base_optim_config
-            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
-
-
-
-
-        # 2. Get epochs and random resets configs
-        epochs = config.get('epochs', 5)  # number of epochs for current stage
-        unfreeze_epoch = config.get('unfreeze_epoch',
-                                    0)  # epoch when components (other than memory) are unfrozen (before epoch starts)
-        random_resets = config.get('random_resets',
-                                   False)  # flag for using random STM resets (recommended, as model should learn transitions between different states)
-        random_resets_from = config.get('random_resets_from', None)  # epoch from which random STM resets are started
-        random_resets_ratio = config.get('random_resets_ratio',
-                                         None)  # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
-
-        # 3. Reset stage step counter
-        self.stage_step = self._init_steps()
-
-        return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
-
     def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
         is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
         if is_staged_unfreeze:
@@ -808,28 +798,31 @@ class MRLTrainer:
                     self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
                     print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
                 elif epoch == update_epoch:
-
-
+                    self.actor.freeze_components('update')
+                    print(
+                        f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
 
             if isinstance(fetch_epoch, tuple):
                 switch_epoch, mem_att_lr = fetch_epoch
-                if epoch ==
+                if epoch == switch_epoch:
                     self.actor.freeze_components('joint')
                     self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
                     print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
                 elif epoch == fetch_epoch:
                     self.actor.freeze_components('fetch')
-                    print(
+                    print(
+                        f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
 
             if isinstance(joint_epoch, tuple):
                 switch_epoch, model_lr = joint_epoch
-                if epoch ==
+                if epoch == switch_epoch:
                     self.actor.unfreeze_components()
                     self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
                     print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
                 elif epoch == joint_epoch:
-
-
+                    self.actor.freeze_components('joint')
+                    print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+
             if epoch == all_epoch:
                 self.actor.unfreeze_components()
                 self.optimizer = self._init_unfreeze_optimizer('all', 0.)
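For context, the unfreeze_epoch handled above is either a single epoch index or a staged tuple (the UnfreezeEpochsStrategy alias earlier in this diff). A hedged sketch of the staged form, assuming the tuple unpacks in the order handled here - update, fetch, joint, all - with each stage optionally carrying a custom learning rate; the epoch indices and rates below are invented:

# Simple form: a single unfreeze epoch.
unfreeze_epoch_simple = 2

# Staged form: (update, fetch, joint, all); an item may be (epoch, custom_lr).
unfreeze_epoch_staged = (
    (0, 2e-5),  # 'update' stage, custom cross_att_lr
    (2, 1e-5),  # 'fetch' stage, custom mem_att_lr
    (4, 5e-6),  # 'joint' stage, custom model_lr
    6,          # 'all' stage - unfreeze every component
)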
@@ -871,6 +864,56 @@ class MRLTrainer:
 
         return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
 
+    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[
+        tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
+        # 1. Set common fields based on config
+        self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
+        self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
+        self.eval_dataset = config.get('eval_dataset', None)  # evaluation dataset for current curriculum stage
+        self.callbacks = config.get('callbacks',
+                                    self.shared_callbacks)  # trainer callbacks for current curriculum stage
+        self.strategy = config.get('strategy',
+                                   MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
+        self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
+        self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
+            'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+            if config.get('separate_memory_lr', False):
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay',
+                                                      self.base_optim_config['critic_weight_decay']),
+                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                }
+            else:
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay',
+                                                      self.base_optim_config['critic_weight_decay']),
+                }
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+        elif self.optim_config != self.base_optim_config:
+            self.optim_config = self.base_optim_config
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+        # 2. Get epochs and random resets configs
+        epochs = config.get('epochs', 5)  # number of epochs for current stage
+        unfreeze_epoch = config.get('unfreeze_epoch',
+                                    0)  # epoch when components (other than memory) are unfrozen (before epoch starts)
+        random_resets = config.get('random_resets',
+                                   False)  # flag for using random STM resets (recommended, as model should learn transitions between different states)
+        random_resets_from = config.get('random_resets_from', None)  # epoch from which random STM resets are started
+        random_resets_ratio = config.get('random_resets_ratio',
+                                         None)  # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
+
+        # 3. Reset stage step counter
+        self.stage_step = self._init_steps()
+
+        return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
     def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
         """Start Memory Reinforcement Learning Curriculum."""
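update_epochs also works as a per-stage override: the re-added _setup_curriculum_step takes the stage value when present and falls back to the shared trainer value otherwise. A hedged sketch of a curriculum stage entry; only keys visible in this diff are shown and all values are illustrative:

# Illustrative curriculum stage; remaining CurriculumConfig keys are omitted.
stage_config = {
    'steps': 4,            # interactions per episode at this stage
    'epochs': 5,           # outer epochs for this stage
    'update_epochs': 6,    # overrides the shared update_epochs for this stage only
    'unfreeze_epoch': 0,
    'random_resets': False,
    'random_resets_from': None,
    'random_resets_ratio': None,
}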
@@ -899,9 +942,11 @@ class MRLTrainer:
         if unfreeze_epoch != 0:
             self.actor.freeze_components('joint')
             if isinstance(unfreeze_epoch, tuple):
-                print(
+                print(
+                    f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
             else:
-                print(
+                print(
+                    f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
 
         # 5. Setup train DataLoader
         if self.use_ddp:
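Taken together, the mrl.py changes hand the RL algorithm stacked tensors instead of per-step rewards: values, rewards and dones are stacked to shape [num_steps, batch_size], rl_step receives trajectories[:-1], and calculate_advantages bootstraps from the last value row. A small shape-only sketch of that layout, with random tensors standing in for critic outputs and rewards:

import torch

num_steps, batch_size = 5, 8                   # e.g. 5 collected interactions, batch of 8
values = torch.randn(num_steps, batch_size)    # stand-in for stacked critic values
rewards = torch.rand(num_steps, batch_size)    # stand-in for stacked rewards
dones = torch.zeros(num_steps, batch_size, dtype=torch.bool)
dones[-1] = True                               # illustrative: last interaction ends the episode

# The PPO implementation below consumes all but the last row and bootstraps from values[-1]:
train_rewards, train_values, next_value = rewards[:-1], values[:-1], values[-1]
print(train_rewards.shape, train_values.shape, next_value.shape)
# torch.Size([4, 8]) torch.Size([4, 8]) torch.Size([8])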
rxnn/training/rl.py
CHANGED
@@ -17,7 +17,7 @@ class RlAlgorithm(ABC):
         pass
 
     @abstractmethod
-    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
     def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
@@ -25,6 +25,9 @@ class RlAlgorithm(ABC):
 
 class PPOConfig(TypedDict):
     clip_eps: float
+    gae_lambda: float
+    gae_gamma: float
+    entropy_coef: float
 
 class PPOAlgorithm(RlAlgorithm):
     def __init__(self, config: PPOConfig):
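A hedged sketch of constructing the algorithm with the new fields, assuming the module is importable as rxnn.training.rl (matching the file path above); the values simply repeat the defaults read in __init__ below:

from rxnn.training.rl import PPOAlgorithm

ppo = PPOAlgorithm({
    'clip_eps': 0.2,       # PPO clipping range
    'gae_lambda': 0.95,    # GAE smoothing factor
    'gae_gamma': 0.99,     # discount factor
    'entropy_coef': 0.01,  # weight of the entropy bonus in the policy loss
})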
@@ -32,6 +35,9 @@ class PPOAlgorithm(RlAlgorithm):
 
         # PPO Config
         self.clip_eps = config.get('clip_eps', 0.2)
+        self.gae_lambda = config.get('gae_lambda', 0.95)
+        self.gae_gamma = config.get('gae_gamma', 0.99)
+        self.entropy_coef = config.get('entropy_coef', 0.01)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -78,11 +84,32 @@ class PPOAlgorithm(RlAlgorithm):
 
         # d) Entropy bonus
         entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
-        policy_loss -=
+        policy_loss -= self.entropy_coef * entropy
 
         return policy_loss
 
-    def
-
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        T, B = rewards.shape
+        advantages = torch.zeros_like(rewards, device=values.device)
+        last_advantage = 0
+        last_value = next_value.detach()
+
+        for t in reversed(range(T)):
+            if t == T - 1:
+                next_values = last_value
+            else:
+                next_values = values[t + 1]
+
+            # Mask next values if episode ended
+            next_values = next_values * (1 - dones[t])
+            delta = rewards[t] + self.gae_gamma * next_values - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+            last_advantage = advantages[t]
+
+        returns = advantages + values
+        return advantages, returns
+
+    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
         normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
-        return normalized_advantages
+        return normalized_advantages, ref_values
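To make the shapes concrete, here is a standalone toy run of the same recursion outside the class; the function mirrors the _compute_gae math added above, and the tensor contents are random and purely illustrative:

import torch

def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    # rewards/values/dones: [T, B]; next_value: [B] bootstrap value from the held-out last step.
    T, _ = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = 0
    for t in reversed(range(T)):
        next_values = next_value if t == T - 1 else values[t + 1]
        next_values = next_values * (1 - dones[t].float())  # mask the value bootstrap at episode ends
        delta = rewards[t] + gamma * next_values - values[t]
        advantages[t] = delta + gamma * lam * last_advantage
        last_advantage = advantages[t]
    return advantages, advantages + values  # (advantages, returns used as critic targets)

T, B = 5, 8
rewards, values = torch.rand(T, B), torch.randn(T + 1, B)  # one extra value row for bootstrapping
dones = torch.zeros(T, B, dtype=torch.bool)
adv, returns = compute_gae(rewards, values[:-1], values[-1], dones)
print(adv.shape, returns.shape)  # torch.Size([5, 8]) torch.Size([5, 8])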
{rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/RECORD
CHANGED
@@ -13,12 +13,12 @@ rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
-rxnn/training/callbacks.py,sha256
+rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=DGevQoimkB9qBEkqIw1kkh5DfLqBM-XGvFkraqh-uYk,51545
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=U-mlTK2hF0wZQslzjlvF4S_sMkeTuSqKsCB3IWEsd2A,4558
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.30.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.30.dist-info/METADATA,sha256=zRJ_oHLqUD0QDKJoGRJ6FH5MC-y0k8nOn_inZ_iEP8c,25960
+rxnn-0.2.30.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.30.dist-info/RECORD,,

{rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/LICENSE
File without changes

{rxnn-0.2.29.dist-info → rxnn-0.2.30.dist-info}/WHEEL
File without changes