rxnn 0.2.24.tar.gz → 0.2.25.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {rxnn-0.2.24 → rxnn-0.2.25}/PKG-INFO +1 -1
  2. {rxnn-0.2.24 → rxnn-0.2.25}/pyproject.toml +1 -1
  3. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/memory/attention.py +0 -1
  4. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/memory/norm.py +11 -13
  5. {rxnn-0.2.24 → rxnn-0.2.25}/LICENSE +0 -0
  6. {rxnn-0.2.24 → rxnn-0.2.25}/README.md +0 -0
  7. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/.DS_Store +0 -0
  8. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/__init__.py +0 -0
  9. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/experimental/__init__.py +0 -0
  10. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/experimental/attention.py +0 -0
  11. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/experimental/models.py +0 -0
  12. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/experimental/moe.py +0 -0
  13. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/memory/__init__.py +0 -0
  14. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/dataset.py +0 -0
  22. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/models.py +0 -0
  23. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/mrl.py +0 -0
  24. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/reward.py +0 -0
  25. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/rl.py +0 -0
  26. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/scheduler.py +0 -0
  27. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/tokenizer.py +0 -0
  28. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/training/utils.py +0 -0
  29. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/__init__.py +0 -0
  30. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/attention.py +0 -0
  31. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/ff.py +0 -0
  32. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/layers.py +0 -0
  33. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/mask.py +0 -0
  34. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/models.py +0 -0
  35. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/moe.py +0 -0
  36. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/positional.py +0 -0
  37. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/transformers/sampler.py +0 -0
  38. {rxnn-0.2.24 → rxnn-0.2.25}/src/rxnn/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.24
+Version: 0.2.25
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.24"
+version = "0.2.25"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
src/rxnn/memory/attention.py
@@ -35,7 +35,6 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-            # self.stm.update_layer(i, new_layer_stm + layer_stm)
             new_stm[i] = new_layer_stm + layer_stm # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
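The deleted comment was a leftover per-layer update call; the live path collects each layer's residual update in new_stm and commits them with a single self.stm.update_all(new_stm). Below is a minimal, self-contained sketch of that pattern; the module choices (LayerNorm, MultiheadAttention), helper names, and sizes are placeholders, not the rxnn classes.

# Sketch of the residual short-term-memory update pattern shown above.
# All names and sizes here are illustrative, not the rxnn API.
import torch
import torch.nn as nn

num_layers, num_slots, dim = 2, 4, 8

stm = torch.zeros(num_layers, num_slots, dim)   # one memory slab per layer
norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])
attns = nn.ModuleList([
    nn.MultiheadAttention(dim, num_heads=2, batch_first=True) for _ in range(num_layers)
])

def update_memory(stm: torch.Tensor, x: list[torch.Tensor]) -> torch.Tensor:
    """Collect per-layer residual updates, then commit them all at once."""
    new_stm = torch.zeros_like(stm)
    for i in range(num_layers):
        layer_stm = stm[i].unsqueeze(0)          # [1, num_slots, dim]
        encoded = x[i]                           # [1, seq_len, dim]
        normalized = norms[i](layer_stm)
        attn_out, _ = attns[i](normalized, encoded, encoded)
        new_stm[i] = (attn_out + layer_stm).squeeze(0)   # residual, as in the diff
    return new_stm                               # caller commits the whole stack in one step

encoded_layers = [torch.randn(1, 6, dim) for _ in range(num_layers)]
stm = update_memory(stm, encoded_layers)
print(stm.shape)  # torch.Size([2, 4, 8])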
src/rxnn/memory/norm.py
@@ -20,8 +20,8 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6
 
         # Learnable parameters
-        self.scale = nn.Parameter(torch.ones(num_slots, 1, dim)) if use_scale else None
-        self.gate = nn.Parameter(torch.full((num_slots, 1, 1), init_gate)) if use_gate else None
+        self.scale = nn.Parameter(torch.ones(num_slots, dim)) if use_scale else None
+        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None
 
         # EMA buffers
         self.register_buffer("ema_rms", torch.ones(num_slots, 1))
@@ -31,28 +31,26 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         nn.init.normal_(self.scale, mean=1.0, std=0.01)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x shape: [batch_size, num_slots, dim]
-        batch_size = x.size(0)
-
         # Calculate current RMS per slot
-        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, slots, 1]
-        slot_rms = current_rms.mean(dim=0) # [slots, 1] (average over batch)
+        # x: [batch_size, num_slots, dim]
+        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, num_slots, 1]
+        slot_rms = current_rms.mean(dim=0) # [num_slots, 1] (average over batch)
 
         # Update EMA during training
         if self.training:
-            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach() # [num_slots, 1]
 
         # Normalize using EMA statistics
-        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+        x_norm = x * torch.rsqrt(self.ema_rms + self.eps) # [batch_size, num_slots, dim] * [num_slots, 1]
 
         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale
+            x_norm = x_norm * self.scale # [batch_size, num_slots, dim] * [num_slots, dim]
 
         # Apply gating mechanism
         if self.use_gate:
-            gate = torch.sigmoid(self.gate) # [slots, 1, 1]
-            return gate * x_norm + (1 - gate) * x
+            gate = torch.sigmoid(self.gate) # [num_slots, 1]
+            return gate * x_norm + (1 - gate) * x # [batch_size, num_slots, dim] * [num_slots, 1]
 
         return x_norm
@@ -77,7 +75,7 @@ class AdaptiveRMSMemoryNorm(nn.Module):
         # x shape: [batch_size, num_slots, dim]
         if self.training and hasattr(self, 'ema_rms'):
             # Compute current RMS across all slots and batch (scalar)
-            current_rms = x.pow(2).mean(-1).mean().sqrt()
+            current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
             self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
             rms = self.ema_rms
         else:
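This hunk only makes the reduction dimension explicit; in PyTorch, mean(-1) and mean(dim=-1) are the same call, so behaviour is unchanged. A tiny check (values arbitrary):

# mean(-1) vs mean(dim=-1): identical results, keyword form is just clearer.
import torch

x = torch.randn(2, 4, 8)
a = x.pow(2).mean(-1).mean().sqrt()        # old spelling
b = x.pow(2).mean(dim=-1).mean().sqrt()    # new spelling
assert torch.equal(a, b)
print(float(b))  # scalar RMS over all slots and batch, as fed into the EMA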