rxnn-0.2.11-py3-none-any.whl → rxnn-0.2.13-py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +2 -3
- rxnn/memory/stm.py +5 -4
- rxnn/training/mrl.py +2 -3
- rxnn/training/rl.py +2 -0
- {rxnn-0.2.11.dist-info → rxnn-0.2.13.dist-info}/METADATA +1 -1
- {rxnn-0.2.11.dist-info → rxnn-0.2.13.dist-info}/RECORD +8 -8
- {rxnn-0.2.11.dist-info → rxnn-0.2.13.dist-info}/LICENSE +0 -0
- {rxnn-0.2.11.dist-info → rxnn-0.2.13.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -36,8 +36,7 @@ class StmMemoryAttention(nn.Module):
|
|
36
36
|
normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
|
37
37
|
new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
|
38
38
|
# self.stm.update_layer(i, new_layer_stm + layer_stm)
|
39
|
-
|
40
|
-
|
41
|
-
# self.stm.update_all(new_stm)
|
39
|
+
new_stm[i] = new_layer_stm + layer_stm # residual
|
40
|
+
self.stm.update_all(new_stm)
|
42
41
|
return self.stm.memory
|
43
42
|
|
rxnn/memory/stm.py
CHANGED
@@ -39,10 +39,12 @@ class ShortTermMemory(nn.Module):
|
|
39
39
|
return self.memory[layer]
|
40
40
|
|
41
41
|
def update_layer(self, layer: int, new_stm: torch.Tensor):
|
42
|
+
self.memory = self.memory.clone()
|
42
43
|
self.memory[layer] = new_stm
|
43
44
|
|
44
45
|
def update_all(self, new_stm: torch.Tensor):
|
45
|
-
self.memory
|
46
|
+
self.memory = new_stm
|
47
|
+
# self.memory.copy_(new_stm)
|
46
48
|
|
47
49
|
def make_trainable(self):
|
48
50
|
if not self.is_trainable:
|
@@ -59,7 +61,7 @@ class ShortTermMemory(nn.Module):
|
|
59
61
|
self.register_buffer('memory', trained_stm)
|
60
62
|
|
61
63
|
def reset(self, init_type: str = None):
|
62
|
-
self.memory
|
64
|
+
self.memory = self._init_tensor(init_type).to(self.memory.device)
|
63
65
|
|
64
66
|
def resize(self, new_stm_size: int, init_type: str = None):
|
65
67
|
self.stm_size = new_stm_size
|
@@ -84,8 +86,7 @@ class ShortTermMemory(nn.Module):
|
|
84
86
|
if use_mean_from_batch:
|
85
87
|
batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
|
86
88
|
delattr(self, 'memory')
|
87
|
-
self.register_buffer('memory',
|
88
|
-
self.memory.copy_(batch_mean)
|
89
|
+
self.register_buffer('memory', batch_mean)
|
89
90
|
else:
|
90
91
|
delattr(self, 'memory')
|
91
92
|
self.register_buffer('memory', self._init_tensor())
|
rxnn/training/mrl.py
CHANGED
@@ -461,7 +461,6 @@ class MRLTrainer:
|
|
461
461
|
# 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
|
462
462
|
# memory, based on collected episode data
|
463
463
|
all_losses = []
|
464
|
-
trajectories_len = len(trajectories)
|
465
464
|
for episode_idx, episode in enumerate(trajectories):
|
466
465
|
episode_steps = episode['steps']
|
467
466
|
should_reset_stm = episode['reset_stm']
|
@@ -514,14 +513,14 @@ class MRLTrainer:
|
|
514
513
|
|
515
514
|
# 9. Update the model in AMP or regular mode
|
516
515
|
if self.use_amp:
|
517
|
-
self.scaler.scale(policy_loss).backward()
|
516
|
+
self.scaler.scale(policy_loss).backward(retain_graph=True)
|
518
517
|
self.scaler.unscale_(self.optimizer)
|
519
518
|
torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
|
520
519
|
error_if_nonfinite=False)
|
521
520
|
self.scaler.step(self.optimizer)
|
522
521
|
self.scaler.update()
|
523
522
|
else:
|
524
|
-
policy_loss.backward()
|
523
|
+
policy_loss.backward(retain_graph=True)
|
525
524
|
torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
|
526
525
|
error_if_nonfinite=False)
|
527
526
|
self.optimizer.step()
|
rxnn/training/rl.py
CHANGED
@@ -43,6 +43,8 @@ class PPOAlgorithm(RlAlgorithm):
|
|
43
43
|
# b) Calculate ratio
|
44
44
|
ratio = (new_log_probs - old_log_probs).exp()
|
45
45
|
|
46
|
+
advantages = advantages.unsqueeze(-1)
|
47
|
+
|
46
48
|
# c) Clipped surrogate loss
|
47
49
|
surr1 = ratio * advantages
|
48
50
|
surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
|
@@ -5,9 +5,9 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
|
|
5
5
|
rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
|
6
6
|
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
7
7
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
-
rxnn/memory/attention.py,sha256=
|
8
|
+
rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
|
9
9
|
rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
|
10
|
-
rxnn/memory/stm.py,sha256=
|
10
|
+
rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -16,9 +16,9 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
|
16
16
|
rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
|
17
17
|
rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
|
18
18
|
rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
|
19
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/mrl.py,sha256=53uOwotmgwKeceMYA6qXQbQMZmggXt_5hq08X-YwrEY,39327
|
20
20
|
rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
|
21
|
-
rxnn/training/rl.py,sha256=
|
21
|
+
rxnn/training/rl.py,sha256=s6wPbg0X6y-RX9-5ctZIDpdJPfExI9DzWUy-TvAiiow,2710
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
23
23
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
24
24
|
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
33
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.13.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.13.dist-info/METADATA,sha256=BOn4qig3IKpYiG0NEWHiF_5NWsWboBqVNeGb2-mYesU,25960
|
37
|
+
rxnn-0.2.13.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|