rxnn 0.2.48__py3-none-any.whl → 0.2.50__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- rxnn/training/mrl.py +13 -7
- rxnn/training/reward.py +153 -46
- {rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/METADATA +2 -1
- {rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/RECORD +6 -6
- {rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/LICENSE +0 -0
- {rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/WHEEL +0 -0
rxnn/training/mrl.py
CHANGED
@@ -71,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]
 
 
 class SamplerConfig(TypedDict):
@@ -215,6 +216,7 @@ class MRLTrainer:
         self.callbacks = []
         self.global_epoch = 0
        self.global_epochs_count = 0
+        self.teacher_forcing = False
 
     def _init_optimizers(
             self,
@@ -452,8 +454,10 @@ class MRLTrainer:
 
             # 11. Update STM with generated response (except last interaction, it's not needed)
             if not is_last_interaction:
-                self.encode_and_update_stm(
-                    …
+                self.encode_and_update_stm(
+                    next_query,
+                    self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                )  # update with generated_answer on GPU
 
             # 12. Store trajectory step
             trajectory: MrlTrajectoryStep = {
@@ -470,7 +474,7 @@ class MRLTrainer:
             # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
             if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                 prev_interaction = (query, answer)
-                query, answer = interaction['query'], detached_answer
+                query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
             # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
             episode_trajectory: MrlTrajectoryEpisode = {
@@ -857,7 +861,10 @@ class MRLTrainer:
 
                 # 10. Encode and update memory for the next interaction
                 if not is_last_interaction:
-                    self.encode_and_update_stm(
+                    self.encode_and_update_stm(
+                        next_query,
+                        self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                    )
 
                 # 11. Accumulate rewards
                 step_reward = torch.tensor(reward).mean().to(self.device)
@@ -870,7 +877,7 @@ class MRLTrainer:
                 # 12. Save previous interaction
                 if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                     prev_interaction = (query, answer)
-                    query, answer = interaction['query'], detached_answer
+                    query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
             avg_episode_reward = (episode_reward / episode_interactions).item()
             # 13. Run eval TensorBoard writer with average episode reward
             self._eval_writer(avg_episode_reward, epoch)
@@ -1000,8 +1007,7 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)
 
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
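The new flag is wired end to end: it is declared in CurriculumConfig, defaults to False in the trainer, is read per curriculum stage via config.get('teacher_forcing', False), and, when enabled, STM is updated from interaction['answer'] instead of the generated answer. A minimal usage sketch follows (assumed usage, not taken from the package docs; other stage-config fields are omitted):

stage_config = {
    'update_epochs': 2,           # existing CurriculumConfig fields shown in the hunk above
    'freeze_embeddings': False,
    'embedding_lr': 1e-5,
    'teacher_forcing': True,      # new in 0.2.50: update STM with interaction['answer'] instead of generated_answer
}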
rxnn/training/reward.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 from enum import Enum
 from typing import Optional
 from .utils import TokenizedDict
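The module-level import now also brings in SmoothingFunction. A standalone illustration (not part of the package) of why smoothing matters for sentence-level BLEU on short answers:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

reference = ['the', 'cat', 'sat', 'on', 'the', 'mat']
hypothesis = ['mat', 'cat', 'the']  # unigrams overlap with the reference, but no bigram does

weights = (0.5, 0.5)  # same shape as the default bleu_ref_weights
plain = sentence_bleu([reference], hypothesis, weights=weights)
smoothed = sentence_bleu([reference], hypothesis, weights=weights,
                         smoothing_function=SmoothingFunction().method4)
# Without smoothing, the zero 2-gram overlap collapses the score to (effectively) zero
# and NLTK prints a warning; method4 keeps the score informative.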
@@ -37,6 +38,7 @@ class MrlRewardModel:
             reward_len: bool = False,
             neg_reward_len: bool = False,
             max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
             len_factor: int = None,
             use_running_mean: bool = True,
             running_mean_decay: float = 0.2,
@@ -44,6 +46,7 @@ class MrlRewardModel:
             bleu_ref_weights: tuple = (0.5, 0.5),
             tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
+            debug_mode: int = 0,
     ):
         self.shared_embedding = shared_embedding.to(device)
         self.device = device
@@ -67,6 +70,7 @@ class MrlRewardModel:
         self.reward_len = reward_len
         self.neg_reward_len = neg_reward_len
         self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
         self.len_factor = len_factor
         self.use_running_mean = use_running_mean
         self.running_mean_decay = running_mean_decay
@@ -74,6 +78,8 @@ class MrlRewardModel:
         self.bleu_saved_weights = bleu_saved_weights
         self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode
 
         self.prev_data_running_mean = None
 
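A minimal construction sketch for the two new options (assumed usage; constructor arguments not visible in this diff are omitted): target_len_as_ref switches the length target from max_rewarded_len to the reference length, and debug_mode controls diagnostics (>= 1 prints aggregate cosine/BLEU stats, == 2 additionally prints per-sample BLEU scores and lengths).

reward_model = MrlRewardModel(
    shared_embedding=embedding,  # embedding module shared with the trained model (assumed variable)
    device=torch.device('cuda'),
    reward_len=True,
    max_rewarded_len=256,        # used as the length target only when target_len_as_ref is False
    target_len_as_ref=True,      # new in 0.2.50
    len_factor=2,
    debug_mode=1,                # new in 0.2.50
)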
@@ -95,59 +101,133 @@ class MrlRewardModel:
         assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
         assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
 
-    def _sentence_bleu(self,
-        …
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
 
         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights
-            …
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
             return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
         else:
             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
 
+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
 
-        …
-        …
-        from nltk.translate.bleu_score import sentence_bleu
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
 
         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights
-            …
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
             saved_bleu = 1 - saved_bleu
 
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
             return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
         else:
             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
 
-    def batch_bleu(self, generated:
-        batch_size = generated.size(0)
-        …
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts
 
-    def …
-        …
-        …
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
 
-        …
-        …
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
 
-        …
-        …
-        …
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)
 
-    def _cosine_sim_running_mean(self, generated:
-        generated_emb = self._sequence_embedding(generated)
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)
 
-        gen_and_saved =
-        gen_and_ref =
-        gen_and_mean =
-        return gen_and_saved, gen_and_ref, gen_and_mean
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)
 
-        …
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
                      include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
         if self.use_running_mean and negative_running_mean:
             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-            …
+                        1 - gen_and_mean)
         elif self.use_running_mean and include_running_mean:
             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
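The masked mean pooling and clamped cosine similarity added above can be reproduced in isolation. The sketch below mirrors the logic of _sequence_embedding and _cosine_sim with hypothetical tensors; it is not the package API.

import torch
import torch.nn.functional as F

def masked_mean_pool(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # embeddings: [batch, seq_len, dim]; attention_mask: [batch, seq_len] with 1 for real tokens
    mask = attention_mask.unsqueeze(-1).float()
    summed = (embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-8)  # avoid division by zero for fully padded rows
    return summed / counts

gen = torch.randn(4, 16, 32)   # hypothetical generated-sequence embeddings
ref = torch.randn(4, 16, 32)   # hypothetical reference embeddings
mask = torch.ones(4, 16)

gen_vec = F.normalize(masked_mean_pool(gen, mask), dim=-1)
ref_vec = F.normalize(masked_mean_pool(ref, mask), dim=-1)
sim = torch.clamp(F.cosine_similarity(gen_vec, ref_vec, dim=1), min=0)  # per-example reward in [0, 1]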
@@ -155,25 +235,26 @@ class MrlRewardModel:
             gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
             return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
 
-    def negative_cosine(self, generated:
-                        saved_data:
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
 
         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
 
-    def len_reward(self, generated: TokenizedDict):
-        …
-        …
-        …
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].to(self.device).sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].to(self.device).sum(dim=1)
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
         return len_reward
 
     def reset_running_mean(self):
         self.prev_data_running_mean = None
 
-    def init_running_mean(self, prev_data:
+    def init_running_mean(self, prev_data: TokenizedDict):
         self.prev_data_running_mean = self._sequence_embedding(prev_data)
 
-    def update_running_mean(self, prev_data:
+    def update_running_mean(self, prev_data: TokenizedDict):
         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
             prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
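A worked sketch of the updated len_reward logic (values are illustrative): answers shorter than the target are scaled by lens / target_lens, while answers at or beyond the target get target_lens / lens when neg_reward_len is enabled, or a flat 1.0 otherwise.

import torch

target_lens = torch.tensor([10.0, 10.0])  # reference lengths when target_len_as_ref=True
lens = torch.tensor([5.0, 20.0])          # generated answer lengths
neg_reward_len = True

neg_lens = target_lens / lens if neg_reward_len else 1.0
len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
# -> tensor([0.5000, 0.5000]): the short answer gets 5/10, the overly long one 10/20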
@@ -193,24 +274,50 @@ class MrlRewardModel:
     ) -> list[float]:
         if prev_data is not None:
             if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data
+                self.init_running_mean(prev_data)
             else:
-                self.update_running_mean(prev_data
+                self.update_running_mean(prev_data)
 
         if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated
-            cosine = self.batch_cosine(generated
-            …
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated
-            cosine = self.batch_cosine(generated
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
                                        negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         else:
-            bleu = self.
-            cosine = self.negative_cosine(generated
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        …
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
         return rewards.tolist()
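The rewritten tail of the method combines the components as sim_rewards plus an optional length term, pre-scales the result, and multiplies by rewards_scale. A compact sketch with hypothetical per-sample values (the private _pre_scale_rewards step is left out):

import torch

bleu = torch.tensor([0.42, 0.31])        # per-sample BLEU from batch_bleu
cosine = torch.tensor([0.85, 0.77])      # per-sample cosine reward from batch_cosine
len_reward = torch.tensor([1.00, 0.50])  # per-sample length reward from len_reward

bleu_factor, cos_factor, len_factor, rewards_scale = 0.5, 0.5, 0.2, 1.0  # hypothetical factors

sim_rewards = bleu_factor * bleu + cos_factor * cosine
rewards = (sim_rewards + len_factor * len_reward) * rewards_scale
print(rewards.tolist())  # ≈ [0.835, 0.64]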
{rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.48
+Version: 0.2.50
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)
{rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/RECORD
CHANGED
@@ -17,8 +17,8 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=L2emJM06u7B9f9T1dFsGXzXX-rsV77ND7L1pAM9Z_Ow,9051
-rxnn/training/mrl.py,sha256=…
-rxnn/training/reward.py,sha256=…
+rxnn/training/mrl.py,sha256=IOi_xbQ47RPgv_2ucT9EkPeWLGBRlgPxKHFeQsYc3Pw,61074
+rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.48.dist-info/LICENSE,sha256=…
-rxnn-0.2.48.dist-info/METADATA,sha256=…
-rxnn-0.2.48.dist-info/WHEEL,sha256=…
-rxnn-0.2.48.dist-info/RECORD,,
+rxnn-0.2.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.50.dist-info/METADATA,sha256=MmlWkWUki9ErQnJ24yP2R9mDykQewDHDcyCQhzopZAw,25997
+rxnn-0.2.50.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.50.dist-info/RECORD,,
{rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/LICENSE
File without changes
{rxnn-0.2.48.dist-info → rxnn-0.2.50.dist-info}/WHEEL
File without changes