rxnn 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -36,8 +36,7 @@ class StmMemoryAttention(nn.Module):
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
             # self.stm.update_layer(i, new_layer_stm + layer_stm)
-            final_layer_stm = new_layer_stm + layer_stm # residual
-            self.stm.update_layer(i, final_layer_stm)
-            # self.stm.update_all(new_stm)
+            new_stm[i] = new_layer_stm + layer_stm # residual
+        self.stm.update_all(new_stm)
         return self.stm.memory
 
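In 0.2.13 the memory-attention forward pass buffers each layer's residual update in a preallocated new_stm tensor and commits the whole stack once via update_all, instead of writing each layer back through update_layer inside the loop. A minimal sketch of the resulting pattern follows; the torch.zeros_like preallocation and the simplified loop body are assumptions about code outside this hunk, not lines taken from it.

    import torch
    import torch.nn as nn

    # Hedged sketch: names mirror the hunk, the constructor and preallocation
    # are illustrative assumptions.
    class StmAttentionSketch(nn.Module):
        def __init__(self, num_layers, stm, attention_layers, memory_norm_layers):
            super().__init__()
            self.num_layers = num_layers
            self.stm = stm
            self.attention_layers = attention_layers
            self.memory_norm_layers = memory_norm_layers

        def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
            new_stm = torch.zeros_like(self.stm.memory)     # one slot per memory layer
            for i in range(self.num_layers):
                layer_stm = self.stm(i)
                encoded_layer_data = x[i]
                normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
                new_layer_stm = self.attention_layers[i](
                    normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask
                )
                new_stm[i] = new_layer_stm + layer_stm      # residual, buffered per layer
            self.stm.update_all(new_stm)                    # single commit replaces per-layer writes
            return self.stm.memory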
rxnn/memory/stm.py CHANGED
@@ -39,10 +39,12 @@ class ShortTermMemory(nn.Module):
         return self.memory[layer]
 
     def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory = self.memory.clone()
         self.memory[layer] = new_stm
 
     def update_all(self, new_stm: torch.Tensor):
-        self.memory.copy_(new_stm)
+        self.memory = new_stm
+        # self.memory.copy_(new_stm)
 
     def make_trainable(self):
         if not self.is_trainable:
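update_layer now clones the memory tensor before the indexed write, and update_all rebinds self.memory to the incoming tensor rather than copying into it in place. Both changes avoid mutating a tensor that an autograd graph may still reference. The toy example below is illustrative only, not library code, and shows the failure mode the clone avoids.

    import torch

    memory = torch.zeros(3)                       # stands in for the STM buffer
    weights = torch.ones(3, requires_grad=True)
    loss = (memory * weights).sum()               # autograd saves `memory` for backward

    # memory[0] = 1.0        # an in-place write here would make loss.backward() fail:
    #                        # "one of the variables needed for gradient computation
    #                        #  has been modified by an inplace operation"
    memory = memory.clone()  # what update_layer now does before its indexed write
    memory[0] = 1.0          # safe: the clone is a different tensor
    loss.backward()          # the saved original is untouched, backward succeeds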
@@ -59,7 +61,7 @@ class ShortTermMemory(nn.Module):
         self.register_buffer('memory', trained_stm)
 
     def reset(self, init_type: str = None):
-        self.memory.copy_(self._init_tensor(init_type))
+        self.memory = self._init_tensor(init_type).to(self.memory.device)
 
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
@@ -84,8 +86,7 @@ class ShortTermMemory(nn.Module):
         if use_mean_from_batch:
             batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
             delattr(self, 'memory')
-            self.register_buffer('memory', self._init_tensor())
-            self.memory.copy_(batch_mean)
+            self.register_buffer('memory', batch_mean)
         else:
             delattr(self, 'memory')
             self.register_buffer('memory', self._init_tensor())
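reset likewise rebinds the buffer to a freshly initialized tensor moved to the old buffer's device, and the batched path registers the batch mean directly as the new buffer instead of initializing and then copying into it. A hedged sketch of the reset pattern, with TinyStm and its _init_tensor as stand-ins for the real ShortTermMemory internals:

    import torch
    import torch.nn as nn

    class TinyStm(nn.Module):
        def __init__(self, shape=(2, 1, 4, 8)):
            super().__init__()
            self.register_buffer('memory', torch.zeros(shape))

        def _init_tensor(self) -> torch.Tensor:
            # stand-in initializer, not the library's _init_tensor
            return torch.randn(self.memory.shape) * 0.02

        def reset(self):
            # rebinding replaces the old in-place copy_; the explicit .to() keeps
            # the fresh tensor on whatever device the module was moved to
            self.memory = self._init_tensor().to(self.memory.device)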
rxnn/training/mrl.py CHANGED
@@ -461,7 +461,6 @@ class MRLTrainer:
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
-        trajectories_len = len(trajectories)
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
@@ -514,14 +513,14 @@ class MRLTrainer:
 
                 # 9. Update the model in AMP or regular mode
                 if self.use_amp:
-                    self.scaler.scale(policy_loss).backward()
+                    self.scaler.scale(policy_loss).backward(retain_graph=True)
                     self.scaler.unscale_(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                    error_if_nonfinite=False)
                     self.scaler.step(self.optimizer)
                     self.scaler.update()
                 else:
-                    policy_loss.backward()
+                    policy_loss.backward(retain_graph=True)
                     torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                    error_if_nonfinite=False)
                     self.optimizer.step()
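Both the AMP and the plain backward call now pass retain_graph=True, which is required when several policy-loss backward passes in one trajectory traverse a shared part of the computation graph (plausibly the short-term memory state carried across steps, which the 0.2.13 STM changes keep attached to autograd). A toy illustration of why the flag is needed, unrelated to the trainer's actual tensors:

    import torch

    # Two losses that share an upstream computation. The first backward() with
    # the default retain_graph=False would free the shared buffers and make the
    # second backward() raise "Trying to backward through the graph a second time".
    params = torch.ones(4, requires_grad=True)
    shared_state = (params * 2).sum()          # shared part of the graph

    loss_step_1 = shared_state * 1.0
    loss_step_2 = shared_state * 3.0

    loss_step_1.backward(retain_graph=True)    # keep the shared graph alive
    loss_step_2.backward()                     # second pass can still traverse it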
rxnn/training/rl.py CHANGED
@@ -43,6 +43,8 @@ class PPOAlgorithm(RlAlgorithm):
         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
 
+        advantages = advantages.unsqueeze(-1)
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
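PPO now unsqueezes the advantages before forming the surrogate terms so they broadcast against the per-token probability ratio. The shapes below are assumptions used only to illustrate the broadcasting, not values taken from the library:

    import torch

    # Assumed shapes: per-token ratio is [batch, seq_len], advantages arrive as
    # [batch]. Multiplying them directly would raise a size-mismatch error;
    # unsqueezing the last dim gives [batch, 1], which broadcasts over seq_len.
    batch, seq_len = 4, 16
    ratio = torch.rand(batch, seq_len)
    advantages = torch.rand(batch)

    advantages = advantages.unsqueeze(-1)             # [batch] -> [batch, 1]
    surr1 = ratio * advantages                        # broadcasts to [batch, seq_len]
    surr2 = torch.clamp(ratio, 0.8, 1.2) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()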
rxnn-0.2.11.dist-info/METADATA → rxnn-0.2.13.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.11
+Version: 0.2.13
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.11.dist-info/RECORD → rxnn-0.2.13.dist-info/RECORD RENAMED
@@ -5,9 +5,9 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=eCsTjJEQVurBjsIlJyk1cDvOGU2YMbYMKiBVOb3mfKg,1874
+rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
+rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,9 +16,9 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=mCsg50bX0iqPozvvQB6CeZ0FYEfuj9ln1p-4IaZBryo,39338
+rxnn/training/mrl.py,sha256=53uOwotmgwKeceMYA6qXQbQMZmggXt_5hq08X-YwrEY,39327
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=T69gLwDlvMMyLuRaJSRmwzO0Mcu0uLXwhAiBB58VK-Y,2663
+rxnn/training/rl.py,sha256=s6wPbg0X6y-RX9-5ctZIDpdJPfExI9DzWUy-TvAiiow,2710
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.11.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.11.dist-info/METADATA,sha256=Latj91qV2ruPOzt-bdaqfGfu0WF42e5qPz6MCBVVtTo,25960
-rxnn-0.2.11.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.11.dist-info/RECORD,,
+rxnn-0.2.13.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.13.dist-info/METADATA,sha256=BOn4qig3IKpYiG0NEWHiF_5NWsWboBqVNeGb2-mYesU,25960
+rxnn-0.2.13.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.13.dist-info/RECORD,,
rxnn-0.2.11.dist-info/LICENSE → rxnn-0.2.13.dist-info/LICENSE RENAMED
File without changes
rxnn-0.2.11.dist-info/WHEEL → rxnn-0.2.13.dist-info/WHEEL RENAMED
File without changes