rxnn 0.2.47__py3-none-any.whl → 0.2.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/norm.py CHANGED
@@ -163,7 +163,7 @@ def init_memory_norm(
  init_scale: float = 1.0,
  per_dim_scale: bool = False,
  ) -> nn.Module:
- assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+ assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
  if norm_type == 'layer':
  return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
  elif norm_type == 'rms':
@@ -172,4 +172,6 @@ def init_memory_norm(
  return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
  elif norm_type == 'positional':
  return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+ elif norm_type == 'classic-rms':
+ return nn.RMSNorm(dim)
  return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
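
Note on the new 'classic-rms' option: it maps straight to PyTorch's built-in nn.RMSNorm (available since torch 2.4; this package already requires torch >=2.6), rather than one of the package's gated/adaptive memory norms. A minimal standalone sketch of what that branch returns, using illustrative sizes that are not taken from the package:

import torch
import torch.nn as nn

dim, num_slots = 256, 128          # illustrative sizes (assumptions, not package defaults)
stm = torch.randn(num_slots, dim)  # stand-in for a short-term memory state

norm = nn.RMSNorm(dim)             # what init_memory_norm(..., norm_type='classic-rms') now returns
print(norm(stm).shape)             # torch.Size([128, 256])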
rxnn/training/mrl.py CHANGED
@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
  moe_aux_loss_scale: Optional[float]
  freeze_embeddings: Optional[bool]
  embedding_lr: Optional[float]
+ use_memory_warmup: Optional[bool]


  class MrlStrategy(Enum):
@@ -70,6 +71,7 @@ class CurriculumConfig(TypedDict):
  update_epochs: Optional[int]
  freeze_embeddings: Optional[bool]
  embedding_lr: Optional[float]
+ teacher_forcing: Optional[bool]


  class SamplerConfig(TypedDict):
@@ -136,6 +138,7 @@ class MRLTrainer:
  self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
  self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
  self.freeze_embeddings = self.shared_freeze_embeddings
+ self.use_memory_warmup = config.get('use_memory_warmup', False)
  # Internal update epochs config
  self.shared_update_epochs = config.get('update_epochs', 10)
  self.update_epochs = self.shared_update_epochs
@@ -213,6 +216,7 @@ class MRLTrainer:
  self.callbacks = []
  self.global_epoch = 0
  self.global_epochs_count = 0
+ self.teacher_forcing = False

  def _init_optimizers(
  self,
@@ -381,6 +385,11 @@ class MRLTrainer:
  self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
  self.stage_step['collect'])

+ def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+ if self.use_memory_warmup:
+ with torch.no_grad():
+ self.encode_and_update_stm(query, answer)
+
  def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
  """Collect trajectories for PPO for current curriculum step."""
  # 1. Init trajectories list
@@ -402,8 +411,13 @@ class MRLTrainer:
  first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
  interactions = interactions[:self.curriculum_steps]
  interactions_len = len(interactions)
+
+ first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+ if reset_done:
+ self.memory_warmup(*first_interaction)
  # 6. Encode and update STM with data to save from first interaction
- self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+ self.encode_and_update_stm(*first_interaction)

  # 7. Save first interaction as data to save (for trajectory state)
  query, answer = first_query, first_answer
@@ -440,8 +454,10 @@ class MRLTrainer:

  # 11. Update STM with generated response (except last interaction, it's not needed)
  if not is_last_interaction:
- self.encode_and_update_stm(next_query,
- generated_answer) # update with generated_answer on GPU
+ self.encode_and_update_stm(
+ next_query,
+ self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+ ) # update with generated_answer on GPU

  # 12. Store trajectory step
  trajectory: MrlTrajectoryStep = {
@@ -458,7 +474,7 @@ class MRLTrainer:
  # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
  if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
  prev_interaction = (query, answer)
- query, answer = interaction['query'], detached_answer
+ query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)

  # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
  episode_trajectory: MrlTrajectoryEpisode = {
@@ -649,6 +665,9 @@ class MRLTrainer:

  self.actor.clone_reset_memory()

+ if should_reset_stm and step_idx == 0:
+ self.memory_warmup(query, answer)
+
  # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
  if self.memory_aware_critic:
  self.encode_and_update_stm(query, answer)
@@ -798,13 +817,16 @@ class MRLTrainer:
  if batch['query']['input_ids'].size(0) == batch_size:
  self._increment_steps('eval')
  # 3. Reset STM with random resets ratio and reward model running mean
- self.reset_stm()
+ reset_stm = self.reset_stm()
  self.reward.reset_running_mean()

  # 4. Get batches for first queries, answers and all follow-up interactions
  first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
  # 5. Encode and update STM with initial interactions (batch)
- self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+ first_interaction = self._move_multiple_batches(first_query, first_answer)
+ if reset_stm:
+ self.memory_warmup(*first_interaction)
+ self.encode_and_update_stm(*first_interaction)

  # 6. Save follow-up interactions len and first query and answer as previous one for iteration
  interactions_len = len(interactions)
@@ -839,7 +861,10 @@ class MRLTrainer:

  # 10. Encode and update memory for the next interaction
  if not is_last_interaction:
- self.encode_and_update_stm(next_query, generated_answer)
+ self.encode_and_update_stm(
+ next_query,
+ self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+ )

  # 11. Accumulate rewards
  step_reward = torch.tensor(reward).mean().to(self.device)
@@ -852,7 +877,7 @@ class MRLTrainer:
  # 12. Save previous interaction
  if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
  prev_interaction = (query, answer)
- query, answer = interaction['query'], detached_answer
+ query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
  avg_episode_reward = (episode_reward / episode_interactions).item()
  # 13. Run eval TensorBoard writer with average episode reward
  self._eval_writer(avg_episode_reward, epoch)
@@ -982,8 +1007,7 @@ class MRLTrainer:
  self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
  self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
  self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+ self.teacher_forcing = config.get('teacher_forcing', False)

  def has_param(field: OptimField) -> bool:
  return field in config and config[field] is not None
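
The two new flags are plain config keys: use_memory_warmup is read once from the trainer config (config.get('use_memory_warmup', False)) and triggers a no-grad encode_and_update_stm() before the first interaction whenever STM was just reset, while teacher_forcing is read per curriculum stage (config.get('teacher_forcing', False)) and makes STM updates use the dataset answer instead of the generated one. A minimal sketch of how they might be set, with all other (unrelated) keys omitted:

# Hypothetical config fragments; only the two new keys are taken from this diff.
mrl_config = {
    # ...existing MrlConfig fields...
    'use_memory_warmup': True,   # no-grad STM warm-up after a reset, before the first step
}

curriculum_stage_config = {
    # ...existing CurriculumConfig fields...
    'teacher_forcing': True,     # update STM from interaction['answer'] instead of the generated answer
}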
rxnn/training/reward.py CHANGED
@@ -1,6 +1,7 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
  from enum import Enum
  from typing import Optional
  from .utils import TokenizedDict
@@ -37,6 +38,7 @@ class MrlRewardModel:
  reward_len: bool = False,
  neg_reward_len: bool = False,
  max_rewarded_len: int = None,
+ target_len_as_ref: bool = False,
  len_factor: int = None,
  use_running_mean: bool = True,
  running_mean_decay: float = 0.2,
@@ -44,6 +46,7 @@ class MrlRewardModel:
  bleu_ref_weights: tuple = (0.5, 0.5),
  tanh_reward_scale: bool = False,
  rewards_scale: float = 1.0,
+ debug_mode: int = 0,
  ):
  self.shared_embedding = shared_embedding.to(device)
  self.device = device
@@ -67,6 +70,7 @@ class MrlRewardModel:
  self.reward_len = reward_len
  self.neg_reward_len = neg_reward_len
  self.max_rewarded_len = max_rewarded_len
+ self.target_len_as_ref = target_len_as_ref
  self.len_factor = len_factor
  self.use_running_mean = use_running_mean
  self.running_mean_decay = running_mean_decay
@@ -74,6 +78,8 @@ class MrlRewardModel:
  self.bleu_saved_weights = bleu_saved_weights
  self.tanh_reward_scale = tanh_reward_scale
  self.rewards_scale = rewards_scale
+ self.bleu_smoothing = SmoothingFunction().method4
+ self.debug_mode = debug_mode

  self.prev_data_running_mean = None

@@ -95,59 +101,133 @@ class MrlRewardModel:
  assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
  assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0

- def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
- from nltk.translate.bleu_score import sentence_bleu
+ def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+ generated, reference, saved_data = input_ids
+ generated_mask, reference_mask, saved_data_mask = masks
+
+ generated = generated.tolist()[:generated_mask.sum().item()]
+ reference = reference.tolist()[:reference_mask.sum().item()]
+ saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+ if self.debug_mode == 2:
+ print('LENS: ', (len(generated), len(reference), len(saved_data)))

  if self.bleu_with_saved_data:
- ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
- saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+ ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+ smoothing_function=self.bleu_smoothing)
+ saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+ smoothing_function=self.bleu_smoothing)
+ if self.debug_mode == 2:
+ print('REF BLEU: ', ref_bleu)
+ print('SAVED BLEU: ', saved_bleu)
+
  return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
  else:
  return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)

+ def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+ generated, reference, saved_data = input_ids
+ generated_mask, reference_mask, saved_data_mask = masks
+
+ generated = generated.tolist()[:generated_mask.sum().item()]
+ reference = reference.tolist()[:reference_mask.sum().item()]
+ saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]

- def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
- saved_data: torch.Tensor) -> float:
- from nltk.translate.bleu_score import sentence_bleu
+ if self.debug_mode == 2:
+ print('LENS: ', (len(generated), len(reference), len(saved_data)))

  if self.bleu_with_saved_data:
- ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
- saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+ ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+ smoothing_function=self.bleu_smoothing)
+ saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+ smoothing_function=self.bleu_smoothing)
  saved_bleu = 1 - saved_bleu

+ if self.debug_mode == 2:
+ print('REF BLEU: ', ref_bleu)
+ print('SAVED BLEU: ', saved_bleu)
+
  return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
  else:
  return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)

- def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
- batch_size = generated.size(0)
- return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
+ def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+ batch_size = generated['input_ids'].size(0)
+
+ return [
+ self._sentence_bleu(
+ input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+ masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+ ) for i in range(batch_size)
+ ]
+
+ def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+ float]:
+ batch_size = generated['input_ids'].size(0)
+
+ return [
+ self._negative_sentence_bleu(
+ input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+ masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+ ) for i in range(batch_size)
+ ]
+
+ def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+ input_ids = sequence['input_ids']
+ attention_mask = sequence['attention_mask']
+
+ # Get embeddings
+ embeddings = self.shared_embedding(input_ids.to(self.device))
+
+ # Apply attention mask
+ mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+ masked_embeddings = embeddings * mask_expanded
+
+ # Compute mean with masking
+ sum_embeddings = torch.sum(masked_embeddings, dim=1)
+ token_counts = torch.sum(mask_expanded, dim=1)
+ token_counts = torch.clamp(token_counts, min=1e-8) # Avoid division by zero
+
+ return sum_embeddings / token_counts

- def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
- embedding = self.shared_embedding(sequence.to(self.device))
- return embedding.mean(dim=1)
+ def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+ generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+ saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+ reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)

- def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
- generated_emb = self._sequence_embedding(generated)
+ gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+ gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)

- gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
- gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
- return gen_and_saved, gen_and_ref
+ if self.debug_mode >= 1:
+ print('GEN AND SAVED: ', gen_and_saved.mean())
+ print('GEN AND REF: ', gen_and_ref.mean())
+ return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)

- def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
- generated_emb = self._sequence_embedding(generated)
+ def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+ generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+ saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+ reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+ running_emb = F.normalize(self.prev_data_running_mean, dim=-1)

- gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
- gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
- gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
- return gen_and_saved, gen_and_ref, gen_and_mean
+ gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+ gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+ gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)

- def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
+ if self.debug_mode >= 1:
+ print('GEN AND SAVED: ', gen_and_saved.mean())
+ print('GEN AND REF: ', gen_and_ref.mean())
+ print('GEN AND MEAN: ', gen_and_mean.mean())
+
+ return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+ def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
  include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
  if self.use_running_mean and negative_running_mean:
  gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
  return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
- 1 - gen_and_mean)
+ 1 - gen_and_mean)
  elif self.use_running_mean and include_running_mean:
  gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
  return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
@@ -155,25 +235,26 @@ class MrlRewardModel:
  gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
  return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref

- def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
- saved_data: torch.Tensor) -> torch.Tensor:
+ def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+ saved_data: TokenizedDict) -> torch.Tensor:
  gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)

  return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref

- def len_reward(self, generated: TokenizedDict):
+ def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+ target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
  lens = generated['attention_mask'].sum(dim=1)
- neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
- len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
+ neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+ len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
  return len_reward

  def reset_running_mean(self):
  self.prev_data_running_mean = None

- def init_running_mean(self, prev_data: torch.Tensor):
+ def init_running_mean(self, prev_data: TokenizedDict):
  self.prev_data_running_mean = self._sequence_embedding(prev_data)

- def update_running_mean(self, prev_data: torch.Tensor):
+ def update_running_mean(self, prev_data: TokenizedDict):
  self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
  prev_data) + self.running_mean_decay * self.prev_data_running_mean

@@ -193,24 +274,50 @@ class MrlRewardModel:
  ) -> list[float]:
  if prev_data is not None:
  if self.prev_data_running_mean is None:
- self.init_running_mean(prev_data['input_ids'])
+ self.init_running_mean(prev_data)
  else:
- self.update_running_mean(prev_data['input_ids'])
+ self.update_running_mean(prev_data)

  if mode == MrlRewardMode.STANDARD:
- bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
- cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
- include_running_mean=prev_data is not None)
+ bleu = self.batch_bleu(generated, reference, saved_data)
+ cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+ if self.debug_mode >= 1:
+ print('STANDARD MODE')
+ print('BLEU: ', sum(bleu) / len(bleu))
+ print('COSINE: ', sum(cosine) / len(cosine))
+
  sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
  elif mode == MrlRewardMode.LONG_RANGE:
- bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
- cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+ bleu = self.batch_bleu(generated, reference, saved_data)
+ cosine = self.batch_cosine(generated, reference, saved_data,
  negative_running_mean=prev_data is not None)
+
+ if self.debug_mode >= 1:
+ print('LONG MODE')
+ print('BLEU: ', sum(bleu) / len(bleu))
+ print('COSINE: ', sum(cosine) / len(cosine))
+
  sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
  else:
- bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
- cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+ bleu = self.negative_bleu(generated, reference, saved_data)
+ cosine = self.negative_cosine(generated, reference, saved_data)
+
+ if self.debug_mode >= 1:
+ print('NEGATIVE MODE')
+ print('BLEU: ', sum(bleu) / len(bleu))
+ print('COSINE: ', sum(cosine) / len(cosine))
+
  sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine

- rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+ if self.reward_len:
+ len_reward = self.len_reward(generated, reference)
+
+ if self.debug_mode >= 1:
+ print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+ rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+ else:
+ rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
  return rewards.tolist()
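
The reworked _sequence_embedding replaces the plain mean over the sequence dimension with an attention-mask-weighted mean, so padding tokens no longer dilute the pooled embedding used for the cosine rewards. A standalone sketch of that pooling, assuming illustrative shapes and a throwaway nn.Embedding in place of the model's shared embedding:

import torch
import torch.nn as nn

emb = nn.Embedding(1000, 64)                     # stand-in for the shared embedding layer
input_ids = torch.randint(0, 1000, (2, 10))      # [batch, seq]
attention_mask = torch.tensor([[1] * 7 + [0] * 3, [1] * 10])

embeddings = emb(input_ids)                      # [batch, seq, dim]
mask = attention_mask.unsqueeze(-1).float()      # [batch, seq, 1]
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8)
print(pooled.shape)                              # torch.Size([2, 64])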
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.47
+ Version: 0.2.49
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
  Requires-Dist: datasets (>=3.5.0,<4.0.0)
  Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+ Requires-Dist: nltk (>=3.9.1,<4.0.0)
  Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
  Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
  Requires-Dist: torch (>=2.6.0,<3.0.0)
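
The new nltk dependency backs the smoothed sentence-BLEU now used by MrlRewardModel (SmoothingFunction().method4), which avoids hard zeros when short token sequences miss higher-order n-gram overlaps. A minimal sketch with made-up token ids:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothing = SmoothingFunction().method4
generated = [12, 87, 5, 301, 9]      # hypothetical token ids
reference = [12, 87, 5, 44, 9, 2]

# Bigram weights (0.5, 0.5) mirror the reward model's default bleu_ref_weights.
score = sentence_bleu([reference], generated, weights=(0.5, 0.5), smoothing_function=smoothing)
print(score)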
@@ -6,7 +6,7 @@ rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/memory/attention.py,sha256=kan6UNPTjLfO7zKNp92hGooldgWPi3li_2-_L5xiErs,2784
- rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
+ rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
  rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/rxt/models.py,sha256=new_YXLe9vfIBPX-pmFRoV523d7yCjEgfTY06EaH3Ms,14605
@@ -17,8 +17,8 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
  rxnn/training/models.py,sha256=L2emJM06u7B9f9T1dFsGXzXX-rsV77ND7L1pAM9Z_Ow,9051
- rxnn/training/mrl.py,sha256=VXwRJ4wQtE0OoRsrsjYlWa2toTvHjoBJ_kril3EiK_A,59811
- rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
+ rxnn/training/mrl.py,sha256=IOi_xbQ47RPgv_2ucT9EkPeWLGBRlgPxKHFeQsYc3Pw,61074
+ rxnn/training/reward.py,sha256=dq3b5DRhBLHOvtlHX3eSSuxYBGYCyV5gVqbzCam4uP8,16112
  rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.47.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.47.dist-info/METADATA,sha256=OqRYFY68bnqQXdXfBNboYLAmXRmojMmR1YFUVQa4Jgo,25960
- rxnn-0.2.47.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.47.dist-info/RECORD,,
+ rxnn-0.2.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.49.dist-info/METADATA,sha256=Yd5xCJVA_rFdzYkTkHZ8tyronArNMOgUQ6VqNF9-vqs,25997
+ rxnn-0.2.49.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.49.dist-info/RECORD,,