rxnn-0.2.10-py3-none-any.whl → rxnn-0.2.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +2 -1
- rxnn/training/callbacks.py +1 -1
- rxnn/training/dataset.py +1 -1
- rxnn/training/mrl.py +3 -2
- rxnn/training/rl.py +9 -9
- {rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/METADATA +1 -1
- {rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/RECORD +9 -9
- {rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/LICENSE +0 -0
- {rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
@@ -42,7 +42,8 @@ class ShortTermMemory(nn.Module):
         self.memory[layer] = new_stm

     def update_all(self, new_stm: torch.Tensor):
-        self.memory
+        self.memory = new_stm
+        # self.memory.copy_(new_stm)

     def make_trainable(self):
         if not self.is_trainable:
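For reference, the difference between the new assignment and the commented-out in-place alternative can be seen with plain tensors. This is a minimal sketch with made-up shapes, not the rxnn module itself:

```python
import torch

stm = torch.zeros(2, 4)      # stands in for the existing memory tensor
new_stm = torch.randn(2, 4)  # stands in for the freshly encoded state

# Rebinding (what update_all now does): the name simply points at new_stm,
# so the old storage is dropped and no copy is made.
memory = new_stm

# In-place copy (the commented-out alternative): keeps the original storage and
# overwrites its values, so existing views of `stm` see the new contents.
stm.copy_(new_stm)
```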
rxnn/training/callbacks.py
CHANGED
@@ -577,7 +577,7 @@ class MrlPrintCallback(MrlTrainerCallback):

     def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
                           critic_loss: float) -> None:
-        print(f'Epoch {epoch} | Step {step} - updated
+        print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')

     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
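A custom callback following the same hook would look roughly like the sketch below; the class name and loss-history field are hypothetical, only the `on_critic_updated` signature is taken from the diff above:

```python
class CriticLossHistoryCallback:
    """Illustrative callback: records critic losses in addition to printing them."""

    def __init__(self):
        self.critic_losses = []  # one float per critic update

    def on_critic_updated(self, actor, critic, epoch: int, step: int, critic_loss: float) -> None:
        self.critic_losses.append(critic_loss)
        print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
```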
rxnn/training/dataset.py
CHANGED
@@ -936,7 +936,7 @@ class MrlCurriculumDataset(Dataset):
         else:
             subset = self.episodes[split_point:-1] if not from_start else self.episodes[0:split_point]
             self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:-1]
-        return self.__class__(subset, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)
+        return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)

     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
         """
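The fix forwards the parent dataset's tokenizer into the subset constructor. A minimal sketch of the pattern with a hypothetical dataset class (`EpisodeDataset` and its fields are illustrative, not the actual MrlCurriculumDataset API):

```python
class EpisodeDataset:
    def __init__(self, episodes, tokenizer=None, query_field='query'):
        self.episodes = episodes
        self.tokenizer = tokenizer
        self.query_field = query_field

    def get_subset(self, split_ratio: float):
        split_point = int(len(self.episodes) * split_ratio)
        subset, self.episodes = self.episodes[split_point:], self.episodes[:split_point]
        # Forward the tokenizer: otherwise the new instance has tokenizer=None and
        # any later pre-tokenization of the subset has nothing to tokenize with.
        return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field)
```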
rxnn/training/mrl.py
CHANGED
@@ -328,7 +328,8 @@ class MRLTrainer:

             # 10. Update STM with generated response (except last interaction, it's not needed)
             if not is_last_interaction:
-                self.encode_and_update_stm(next_query,
+                self.encode_and_update_stm(next_query,
+                                           generated_answer)  # update with generated_answer on GPU

             # 11. Store trajectory step
             trajectory: MrlTrajectoryStep = {
@@ -438,7 +439,7 @@ class MRLTrainer:
             critic_losses.append(critic_loss)

         # 7. Calculate mean loss for epoch callbacks
-        critic_mean_loss = torch.
+        critic_mean_loss = torch.tensor(critic_losses).mean().item()

         return critic_mean_loss

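The completed line just averages the per-step critic losses collected during the epoch; an isolated illustration of that expression with made-up values:

```python
import torch

critic_losses = [0.52, 0.47, 0.44]  # float losses appended inside the update loop
critic_mean_loss = torch.tensor(critic_losses).mean().item()
print(critic_mean_loss)  # ~0.4767
```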
rxnn/training/rl.py
CHANGED
@@ -54,16 +54,16 @@ class PPOAlgorithm(RlAlgorithm):

         return policy_loss

-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
+    # def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
+    #     advantages = torch.zeros_like(rewards, device=values.device)
+    #     last_advantage = 0
+    #     for t in reversed(range(rewards.size(0))):
+    #         delta = rewards[t] + self.gae_gamma * next_value - values[t]
+    #         advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+    #         last_advantage = advantages[t]
+    #     return advantages

     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        advantages =
+        advantages = rewards - values
         normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
         return normalized_advantages
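The restored `calculate_advantages` path is simple reward-minus-value with normalization (GAE stays commented out). A standalone sketch of just that computation:

```python
import torch

def calculate_advantages(rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Plain advantage: reward minus critic value estimate, then normalized to
    # zero mean / unit variance so the PPO policy-loss scale stays stable.
    advantages = rewards - values
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

rewards = torch.tensor([1.0, 0.5, 0.0])
values = torch.tensor([0.8, 0.6, 0.1])
print(calculate_advantages(rewards, values))
```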
{rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/RECORD
CHANGED
@@ -7,18 +7,18 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=
+rxnn/memory/stm.py,sha256=S5CtPI2KXxjs_vvMtb-w57ZPN3TmvVvU3TBHG2au2VE,3879
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
-rxnn/training/callbacks.py,sha256
-rxnn/training/dataset.py,sha256=
+rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
+rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=mCsg50bX0iqPozvvQB6CeZ0FYEfuj9ln1p-4IaZBryo,39338
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=T69gLwDlvMMyLuRaJSRmwzO0Mcu0uLXwhAiBB58VK-Y,2663
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.12.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.12.dist-info/METADATA,sha256=HvEJSZUelxjiAKWAQ3wwbNtNmMsJjxlstZZModU9UMw,25960
+rxnn-0.2.12.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.12.dist-info/RECORD,,
{rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/LICENSE
File without changes

{rxnn-0.2.10.dist-info → rxnn-0.2.12.dist-info}/WHEEL
File without changes