rxnn 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -36,7 +36,8 @@ class StmMemoryAttention(nn.Module):
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
  new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
  # self.stm.update_layer(i, new_layer_stm + layer_stm)
- new_stm[i] = new_layer_stm + layer_stm # residual
- self.stm.update_all(new_stm)
+ final_layer_stm = new_layer_stm + layer_stm # residual
+ self.stm.update_layer(i, final_layer_stm)
+ # self.stm.update_all(new_stm)

  return self.stm.memory
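For reference, a minimal runnable sketch of the per-layer write-back pattern this hunk switches to: each layer's updated short-term memory is added to the previous state (residual) and written back immediately, rather than collected into new_stm for a single update_all() call. The helper name, list-based STM and tensor shapes below are assumptions, not the rxnn implementation.

# Hypothetical sketch (not the rxnn source): per-layer residual STM write-back.
import torch

def update_stm_per_layer(stm_layers: list, attention_outputs: list) -> list:
    for i, (layer_stm, new_layer_stm) in enumerate(zip(stm_layers, attention_outputs)):
        final_layer_stm = new_layer_stm + layer_stm   # residual, as in the diff
        stm_layers[i] = final_layer_stm               # stand-in for self.stm.update_layer(i, ...)
    return stm_layers

# Example with three layers of shape (batch, stm_slots, dim); shapes are illustrative.
layers = [torch.zeros(1, 8, 16) for _ in range(3)]
outputs = [torch.randn(1, 8, 16) for _ in range(3)]
print([t.shape for t in update_stm_per_layer(layers, outputs)])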
 
rxnn/training/callbacks.py CHANGED
@@ -577,7 +577,7 @@ class MrlPrintCallback(MrlTrainerCallback):

  def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
  critic_loss: float) -> None:
- print(f'Epoch {epoch} | Step {step} - updated policy loss {critic_loss}')
+ print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')

  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
  print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
rxnn/training/dataset.py CHANGED
@@ -936,7 +936,7 @@ class MrlCurriculumDataset(Dataset):
  else:
  subset = self.episodes[split_point:-1] if not from_start else self.episodes[0:split_point]
  self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:-1]
- return self.__class__(subset, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)
+ return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)

  def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
  """
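A toy illustration of why forwarding tokenizer=self.tokenizer matters here: if the dataset's __init__ requires a tokenizer, re-instantiating self.__class__(subset, ...) without it raises a TypeError. The class below is hypothetical, not the rxnn MrlCurriculumDataset.

# Hypothetical minimal dataset showing the failure mode fixed by passing the tokenizer through.
class ToyCurriculumDataset:
    def __init__(self, episodes, tokenizer, query_field='query'):
        self.episodes = episodes
        self.tokenizer = tokenizer
        self.query_field = query_field

    def get_subset(self, split_point: int):
        subset = self.episodes[split_point:]
        self.episodes = self.episodes[:split_point]
        # Omitting tokenizer=self.tokenizer here would raise:
        # TypeError: __init__() missing 1 required positional argument: 'tokenizer'
        return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field)

ds = ToyCurriculumDataset(list(range(10)), tokenizer=object())
val = ds.get_subset(8)
print(len(ds.episodes), len(val.episodes))  # 8 2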
rxnn/training/mrl.py CHANGED
@@ -328,7 +328,8 @@ class MRLTrainer:

  # 10. Update STM with generated response (except last interaction, it's not needed)
  if not is_last_interaction:
- self.encode_and_update_stm(next_query, generated_answer) # update with generated_answer on GPU
+ self.encode_and_update_stm(next_query,
+ generated_answer) # update with generated_answer on GPU

  # 11. Store trajectory step
  trajectory: MrlTrajectoryStep = {
@@ -438,7 +439,7 @@ class MRLTrainer:
  critic_losses.append(critic_loss)

  # 7. Calculate mean loss for epoch callbacks
- critic_mean_loss = torch.stack(critic_losses).mean().item()
+ critic_mean_loss = torch.tensor(critic_losses).mean().item()

  return critic_mean_loss

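The stack-to-tensor switch matters when critic_losses holds plain Python floats (for example values collected via loss.item()): torch.stack accepts only tensors, while torch.tensor builds a new tensor from numbers. A small self-contained check, with the loss values assumed:

import torch

critic_losses = [0.52, 0.47, 0.44]   # assumed: floats appended via critic_loss = loss.item()
critic_mean_loss = torch.tensor(critic_losses).mean().item()
print(critic_mean_loss)
# torch.stack(critic_losses) would raise TypeError: expected Tensor as element 0 in argument 0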
rxnn/training/reward.py CHANGED
@@ -103,9 +103,9 @@ class MrlRewardModel:
  if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
  bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
  cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
- return (self.bleu_factor * torch.tensor(bleu) + self.cos_factor * cosine).tolist()
+ return (self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine).tolist()
  else:
  bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
  cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
- return (self.neg_bleu_factor * torch.tensor(bleu) + self.neg_cos_factor * cosine).tolist()
+ return (self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine).tolist()

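Passing device=self.device keeps the BLEU tensor, built from a Python list on the CPU, on the same device as the cosine similarity tensor; adding tensors that live on different devices raises a RuntimeError. A minimal sketch, with the weights, scores and device picked here purely for illustration:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
bleu = [0.35, 0.41]                                   # plain Python floats, e.g. from batch_bleu
cosine = torch.tensor([0.82, 0.79], device=device)    # similarity scores already on `device`

bleu_factor, cos_factor = 0.5, 0.5                    # assumed weighting factors
reward = (bleu_factor * torch.tensor(bleu, device=device) + cos_factor * cosine).tolist()
print(reward)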
rxnn/training/rl.py CHANGED
@@ -54,16 +54,16 @@ class PPOAlgorithm(RlAlgorithm):

  return policy_loss

- def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
- advantages = torch.zeros_like(rewards, device=values.device)
- last_advantage = 0
- for t in reversed(range(rewards.size(0))):
- delta = rewards[t] + self.gae_gamma * next_value - values[t]
- advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
- last_advantage = advantages[t]
- return advantages
+ # def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
+ # advantages = torch.zeros_like(rewards, device=values.device)
+ # last_advantage = 0
+ # for t in reversed(range(rewards.size(0))):
+ # delta = rewards[t] + self.gae_gamma * next_value - values[t]
+ # advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+ # last_advantage = advantages[t]
+ # return advantages

  def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
- advantages = self._compute_gae(rewards, values[:-1], values[-1])
+ advantages = rewards - values
  normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
  return normalized_advantages
rxnn-0.2.9.dist-info/METADATA → rxnn-0.2.11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.9
+ Version: 0.2.11
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
rxnn-0.2.9.dist-info/RECORD → rxnn-0.2.11.dist-info/RECORD CHANGED
@@ -5,7 +5,7 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
  rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
+ rxnn/memory/attention.py,sha256=eCsTjJEQVurBjsIlJyk1cDvOGU2YMbYMKiBVOb3mfKg,1874
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
  rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -13,12 +13,12 @@ rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
- rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
- rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
+ rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
+ rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
- rxnn/training/mrl.py,sha256=cftSa6zS3jfremxX0SQrxtbhbEHfZ-nvZT3Hl6GlgpI,39282
- rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
- rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
+ rxnn/training/mrl.py,sha256=mCsg50bX0iqPozvvQB6CeZ0FYEfuj9ln1p-4IaZBryo,39338
+ rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
+ rxnn/training/rl.py,sha256=T69gLwDlvMMyLuRaJSRmwzO0Mcu0uLXwhAiBB58VK-Y,2663
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.9.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.9.dist-info/METADATA,sha256=ERPLS1G1D0zDUR3OUpTAPUHQ47Xjd3U4b4bSMnqA4p4,25959
- rxnn-0.2.9.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.9.dist-info/RECORD,,
+ rxnn-0.2.11.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.11.dist-info/METADATA,sha256=Latj91qV2ruPOzt-bdaqfGfu0WF42e5qPz6MCBVVtTo,25960
+ rxnn-0.2.11.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.11.dist-info/RECORD,,