rxnn 0.2.36__tar.gz → 0.2.38__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.36 → rxnn-0.2.38}/PKG-INFO +1 -1
  2. {rxnn-0.2.36 → rxnn-0.2.38}/pyproject.toml +1 -1
  3. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/base.py +2 -1
  4. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/bml.py +8 -0
  5. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/rl.py +7 -8
  6. {rxnn-0.2.36 → rxnn-0.2.38}/LICENSE +0 -0
  7. {rxnn-0.2.36 → rxnn-0.2.38}/README.md +0 -0
  8. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/rxt/models.py +0 -0
  20. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/utils.py +0 -0
{rxnn-0.2.36 → rxnn-0.2.38}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.36
+Version: 0.2.38
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.36 → rxnn-0.2.38}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.36"
+version = "0.2.38"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/base.py
@@ -82,6 +82,7 @@ class BaseTrainer(ABC):
             dataset: torch.utils.data.Dataset = None,
             optimizer: torch.optim.Optimizer = None,
             scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+            ddp_find_unused_parameters: bool = False,
     ) -> None:
         self.is_running = True
         if dataset is None:
@@ -94,7 +95,7 @@ class BaseTrainer(ABC):
         if self.use_ddp:
             rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-            self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
+            self.model = DistributedDataParallel(self.model, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
         train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
         dataloader = torch.utils.data.DataLoader(
             dataset,
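
Note: the base.py change threads a new ddp_find_unused_parameters flag from the BaseTrainer constructor into PyTorch's DistributedDataParallel wrapper. Below is a minimal usage sketch of that PyTorch option, not code from the package; the wrap_for_ddp helper name is hypothetical. Setting find_unused_parameters=True lets DDP tolerate parameters that receive no gradient in a given step (for example, experts a MoE router did not select), at the cost of an extra graph traversal per iteration.

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_for_ddp(model: torch.nn.Module, device: torch.device,
                 ddp_find_unused_parameters: bool = False) -> DistributedDataParallel:
    # Assumes torch.distributed.init_process_group() has already been called,
    # as BaseTrainer does with the 'nccl' backend before wrapping the model.
    return DistributedDataParallel(
        model.to(device),
        device_ids=[device.index],
        find_unused_parameters=ddp_find_unused_parameters,
    )
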
{rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/bml.py
@@ -51,6 +51,10 @@ class MLMTrainer(BaseTrainer):
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.encoder.model.moe_router_loss()
+
+        if self.use_ddp:
+            router_loss = distributed_mean(router_loss)
+
         loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
@@ -152,6 +156,10 @@ class AutoregressiveTrainer(BaseTrainer):
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
+
+        if self.use_ddp:
+            router_loss = distributed_mean(router_loss)
+
         loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
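
Note: both bml.py hunks average the MoE router auxiliary loss across ranks before folding it into the total loss, so every DDP replica applies the same auxiliary term. The distributed_mean helper comes from the package itself and its implementation is not shown in this diff; the sketch below is only an assumption of what such a helper typically looks like, built on an all-reduce.

import torch
import torch.distributed as dist

def distributed_mean(value: torch.Tensor) -> torch.Tensor:
    # Assumed sketch: average a (scalar) loss tensor across all DDP ranks.
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process fallback
    value = value.clone()
    dist.all_reduce(value, op=dist.ReduceOp.SUM)  # sum the value over ranks
    return value / dist.get_world_size()          # divide to get the mean
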
{rxnn-0.2.36 → rxnn-0.2.38}/src/rxnn/training/rl.py
@@ -112,21 +112,20 @@ class PPOAlgorithm(RlAlgorithm):
         advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
         next_value = last_value
-        next_done = last_done
+        next_done = last_done.float()
         dones = dones.float()
 
         for t in reversed(range(trajectory_len)):
-            if t == trajectory_len - 1:
-                # For the last step, use the provided last_value
-                delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
-            else:
-                # For other steps, use the next value in the trajectory
-                delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
-
+            # Calculate delta from rewards, stored next_value, masked by stored next_done, and values
+            delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
+            # Calculate advantages based on delta, gamma/lambda factors and last advantage, masked by current done flags
             advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
+            # Store current step data as last_advantage, next_done and next_value, for the next iteration step
             last_advantage = advantages[t]
             next_done = dones[t]
+            next_value = values[t]
 
+        # Calculate reference returns, based on advantages and values, and return them with advantages for critic update
         returns = advantages + values
         return advantages, returns
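
Note: the rl.py hunk simplifies PPO's generalized advantage estimation (GAE): every step now bootstraps from the carried next_value and next_done (updated at the end of each iteration) instead of branching on the final step, and last_done is cast to float so the masking matches dones. A self-contained sketch of the same recursion, with illustrative argument names and defaults that are not taken from rxnn:

import torch

def compute_gae(rewards, values, dones, last_value, last_done,
                gamma: float = 0.99, lam: float = 0.95):
    # rewards, values, dones: tensors of shape [trajectory_len, ...]
    # last_value, last_done: bootstrap value and done flag for the step after the trajectory
    trajectory_len = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    last_advantage = 0.0
    next_value, next_done = last_value, last_done.float()
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        # TD error bootstrapped from the next step's value, masked if that step ended an episode
        delta = rewards[t] + gamma * next_value * (1 - next_done) - values[t]
        advantages[t] = delta + gamma * lam * (1 - dones[t]) * last_advantage
        # Carry this step's advantage, done flag and value into the next (earlier) iteration
        last_advantage = advantages[t]
        next_done = dones[t]
        next_value = values[t]
    returns = advantages + values  # critic regression targets
    return advantages, returns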