rxnn 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -35,7 +35,6 @@ class StmMemoryAttention(nn.Module):
            encoded_layer_data = x[i]
            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-           # self.stm.update_layer(i, new_layer_stm + layer_stm)
            new_stm[i] = new_layer_stm + layer_stm # residual
        self.stm.update_all(new_stm)
        return self.stm.memory
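The edit above removes the dead per-layer commit (the commented-out self.stm.update_layer call) in favor of staging each layer's residual update in new_stm and committing all layers at once via update_all. A minimal sketch of that stage-then-commit pattern, using a hypothetical TinyStm stand-in and stock PyTorch attention in place of the package's real classes:

import torch
import torch.nn as nn

class TinyStm:
    """Hypothetical stand-in for the package's short-term memory container."""
    def __init__(self, num_layers: int, num_slots: int, dim: int):
        self.memory = torch.zeros(num_layers, num_slots, dim)

    def update_all(self, new_memory: torch.Tensor):
        self.memory = new_memory

num_layers, num_slots, seq_len, dim = 2, 8, 16, 32
stm = TinyStm(num_layers, num_slots, dim)
attn = nn.ModuleList(nn.MultiheadAttention(dim, 4, batch_first=True) for _ in range(num_layers))
norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(num_layers))

x = torch.randn(num_layers, seq_len, dim)      # encoded layer data (batch dim omitted)
new_stm = torch.zeros_like(stm.memory)
for i in range(num_layers):
    layer_stm = stm.memory[i].unsqueeze(0)     # [1, num_slots, dim]
    q = norms[i](layer_stm)                    # normalize memory before attending
    kv = x[i].unsqueeze(0)                     # [1, seq_len, dim]
    out, _ = attn[i](q, kv, kv)                # memory slots attend to encoded data
    new_stm[i] = (out + layer_stm).squeeze(0)  # residual, staged rather than committed per layer
stm.update_all(new_stm)                        # single commit of all layers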
rxnn/memory/norm.py CHANGED
@@ -20,8 +20,8 @@ class AdaptivePositionalMemoryNorm(nn.Module):
        self.eps = 1e-6

        # Learnable parameters
-       self.scale = nn.Parameter(torch.ones(num_slots, 1, dim)) if use_scale else None
-       self.gate = nn.Parameter(torch.full((num_slots, 1, 1), init_gate)) if use_gate else None
+       self.scale = nn.Parameter(torch.ones(num_slots, dim)) if use_scale else None
+       self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None

        # EMA buffers
        self.register_buffer("ema_rms", torch.ones(num_slots, 1))
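The reshape above is about broadcasting. Against an input of shape [batch_size, num_slots, dim], the old [num_slots, 1, dim] parameter lines its slot axis up with the batch axis, so the multiply either fails or silently mis-broadcasts unless batch_size happens to equal num_slots; the new [num_slots, dim] shape aligns with the trailing dimensions and yields a true per-slot scale. A quick check with toy shapes (not package code):

import torch

batch, num_slots, dim = 4, 8, 32
x = torch.randn(batch, num_slots, dim)

# New shape [num_slots, dim]: aligns with the trailing dims of x, so every slot
# gets its own scale for any batch size.
scale = torch.ones(num_slots, dim)
assert (x * scale).shape == (batch, num_slots, dim)

# Old shape [num_slots, 1, dim]: trailing alignment pits num_slots against batch,
# so with batch=4 and num_slots=8 the multiply raises a size-mismatch error.
old_scale = torch.ones(num_slots, 1, dim)
try:
    x * old_scale
except RuntimeError as err:
    print("old shape fails to broadcast:", err)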
@@ -31,28 +31,26 @@ class AdaptivePositionalMemoryNorm(nn.Module):
        nn.init.normal_(self.scale, mean=1.0, std=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-       # x shape: [batch_size, num_slots, dim]
-       batch_size = x.size(0)
-
        # Calculate current RMS per slot
-       current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, slots, 1]
-       slot_rms = current_rms.mean(dim=0) # [slots, 1] (average over batch)
+       # x: [batch_size, num_slots, dim]
+       current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, num_slots, 1]
+       slot_rms = current_rms.mean(dim=0) # [num_slots, 1] (average over batch)

        # Update EMA during training
        if self.training:
-           self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+           self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach() # [num_slots, 1]

        # Normalize using EMA statistics
-       x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+       x_norm = x * torch.rsqrt(self.ema_rms + self.eps) # [batch_size, num_slots, dim] * [num_slots, 1]

        # Apply learned scale per slot
        if self.scale is not None:
-           x_norm = x_norm * self.scale
+           x_norm = x_norm * self.scale # [batch_size, num_slots, dim] * [num_slots, dim]

        # Apply gating mechanism
        if self.use_gate:
-           gate = torch.sigmoid(self.gate) # [slots, 1, 1]
-           return gate * x_norm + (1 - gate) * x
+           gate = torch.sigmoid(self.gate) # [num_slots, 1]
+           return gate * x_norm + (1 - gate) * x # [batch_size, num_slots, dim] * [num_slots, 1]

        return x_norm
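Taken together, the new forward pass normalizes by the per-slot EMA RMS, applies the per-slot scale, and interpolates between normalized and raw input with a per-slot sigmoid gate. A condensed, self-contained rendition (the function name and toy shapes are illustrative, not the package API):

import torch

def adaptive_positional_norm(x, ema_rms, scale, gate_logits, eps=1e-6):
    # x: [B, S, D]; ema_rms: [S, 1]; scale: [S, D]; gate_logits: [S, 1]
    x_norm = x * torch.rsqrt(ema_rms + eps)  # per-slot RMS normalization
    x_norm = x_norm * scale                  # learned per-slot scale
    gate = torch.sigmoid(gate_logits)        # per-slot gate in (0, 1)
    return gate * x_norm + (1 - gate) * x    # interpolate normalized vs. raw

B, S, D = 4, 8, 32
out = adaptive_positional_norm(
    torch.randn(B, S, D), torch.ones(S, 1), torch.ones(S, D), torch.zeros(S, 1))
assert out.shape == (B, S, D)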
@@ -77,7 +75,7 @@ class AdaptiveRMSMemoryNorm(nn.Module):
        # x shape: [batch_size, num_slots, dim]
        if self.training and hasattr(self, 'ema_rms'):
            # Compute current RMS across all slots and batch (scalar)
-           current_rms = x.pow(2).mean(-1).mean().sqrt()
+           current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
            self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
            rms = self.ema_rms
        else:
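Unlike the per-slot statistics above, AdaptiveRMSMemoryNorm keeps one scalar RMS for the whole batch and all slots; the code change itself is cosmetic (spelling out dim=-1). A toy version of the scalar EMA update, with an assumed decay value:

import torch

x = torch.randn(4, 8, 32)                           # [batch, num_slots, dim]
current_rms = x.pow(2).mean(dim=-1).mean().sqrt()   # 0-dim scalar tensor
decay = 0.99                                        # assumed; not visible in this hunk
ema_rms = torch.ones(())
ema_rms = ema_rms * decay + current_rms * (1 - decay)
print(ema_rms.item())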
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.24
+ Version: 0.2.25
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/RECORD RENAMED
@@ -5,8 +5,8 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
  rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
- rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
+ rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
+ rxnn/memory/norm.py,sha256=mu_6iZJe61ag627csfJN2JK6QmmzofjOEhxV4ZWblXs,6410
  rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.24.dist-info/METADATA,sha256=PrVfcCd8NBFtFnD8lAJqU7UW3lLEc-Tr7MQhK6obvuo,25960
- rxnn-0.2.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.24.dist-info/RECORD,,
+ rxnn-0.2.25.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.25.dist-info/METADATA,sha256=nuGFk4oqSMhn6vrw2KZs4RtY0_ZLowg29IlkNVHZ6Jo,25960
+ rxnn-0.2.25.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.25.dist-info/RECORD,,
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/LICENSE RENAMED
File without changes
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/WHEEL RENAMED
File without changes