rxnn-0.2.24-py3-none-any.whl → rxnn-0.2.25-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +0 -1
- rxnn/memory/norm.py +11 -13
- {rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/METADATA +1 -1
- {rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/RECORD +6 -6
- {rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/LICENSE +0 -0
- {rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -35,7 +35,6 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-            # self.stm.update_layer(i, new_layer_stm + layer_stm)
             new_stm[i] = new_layer_stm + layer_stm # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
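The only change here is dropping the commented-out per-layer `update_layer` call, leaving the residual write into `new_stm` and a single `update_all` after the loop. For orientation, the update pattern reads roughly like the sketch below. It is a minimal stand-in, not the library's `StmMemoryAttention`: the fixed batch size, the `nn.MultiheadAttention` used in place of `self.attention_layers[i]`, the `nn.LayerNorm` in place of the memory norm layers, and the buffer named `stm` are assumptions made for illustration, and the attention mask is omitted.

```python
import torch
import torch.nn as nn

class TinyStmAttention(nn.Module):
    """Sketch of the residual STM update pattern from the diff (not the rxnn class)."""
    def __init__(self, num_layers: int, num_slots: int, dim: int,
                 batch_size: int = 1, num_heads: int = 4):
        super().__init__()
        self.memory_norm_layers = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])
        self.attention_layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads, batch_first=True) for _ in range(num_layers)]
        )
        # Stand-in short-term memory: [num_layers, batch, num_slots, dim]
        self.register_buffer("stm", torch.zeros(num_layers, batch_size, num_slots, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: encoded data per layer, [num_layers, batch, seq_len, dim]
        new_stm = torch.zeros_like(self.stm)
        for i in range(self.stm.size(0)):
            layer_stm = self.stm[i]                                   # [batch, num_slots, dim]
            encoded_layer_data = x[i]                                 # [batch, seq_len, dim]
            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
            new_layer_stm, _ = self.attention_layers[i](
                normalized_layer_stm, encoded_layer_data, encoded_layer_data
            )
            new_stm[i] = new_layer_stm + layer_stm                    # residual, as in the diff
        self.stm = new_stm                                            # single update_all-style write
        return self.stm

mem_attn = TinyStmAttention(num_layers=2, num_slots=8, dim=32, batch_size=4)
encoded = torch.randn(2, 4, 16, 32)       # [num_layers, batch, seq_len, dim]
updated = mem_attn(encoded)               # [num_layers, batch, num_slots, dim]
```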
rxnn/memory/norm.py
CHANGED
@@ -20,8 +20,8 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6

         # Learnable parameters
-        self.scale = nn.Parameter(torch.ones(num_slots,
-        self.gate = nn.Parameter(torch.full((num_slots, 1
+        self.scale = nn.Parameter(torch.ones(num_slots, dim)) if use_scale else None
+        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None

         # EMA buffers
         self.register_buffer("ema_rms", torch.ones(num_slots, 1))
@@ -31,28 +31,26 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         nn.init.normal_(self.scale, mean=1.0, std=0.01)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x shape: [batch_size, num_slots, dim]
-        batch_size = x.size(0)
-
         # Calculate current RMS per slot
-
-
+        # x: [batch_size, num_slots, dim]
+        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, num_slots, 1]
+        slot_rms = current_rms.mean(dim=0) # [num_slots, 1] (average over batch)

         # Update EMA during training
         if self.training:
-            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach() # [num_slots, 1]

         # Normalize using EMA statistics
-        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+        x_norm = x * torch.rsqrt(self.ema_rms + self.eps) # [batch_size, num_slots, dim] * [num_slots, 1]

         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale
+            x_norm = x_norm * self.scale # [batch_size, num_slots, dim] * [num_slots, dim]

         # Apply gating mechanism
         if self.use_gate:
-            gate = torch.sigmoid(self.gate) # [
-            return gate * x_norm + (1 - gate) * x
+            gate = torch.sigmoid(self.gate) # [num_slots, 1]
+            return gate * x_norm + (1 - gate) * x # [batch_size, num_slots, dim] * [num_slots, 1]

         return x_norm

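Taken together, the constructor and forward changes above describe a per-slot RMS norm with EMA statistics, a learned per-slot scale, and a sigmoid gate that blends the normalized and raw memory. The sketch below is assembled only from the diffed lines and is not the exact rxnn class; the constructor signature and the `init_gate`/`decay` defaults are assumptions for illustration.

```python
import torch
import torch.nn as nn

class PositionalMemoryNormSketch(nn.Module):
    """Per-slot EMA-RMS norm with learned scale and gate (sketch from the diff)."""
    def __init__(self, num_slots: int, dim: int, use_scale: bool = True,
                 use_gate: bool = True, init_gate: float = -2.0, decay: float = 0.99):
        super().__init__()
        self.eps = 1e-6
        self.decay = decay
        self.use_gate = use_gate
        # Learnable parameters (as in the 0.2.25 lines)
        self.scale = nn.Parameter(torch.ones(num_slots, dim)) if use_scale else None
        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None
        if self.scale is not None:
            nn.init.normal_(self.scale, mean=1.0, std=0.01)   # per-slot scale near 1.0
        # EMA buffer: one running RMS value per memory slot
        self.register_buffer("ema_rms", torch.ones(num_slots, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, num_slots, dim]
        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()   # [batch, num_slots, 1]
        slot_rms = current_rms.mean(dim=0)                         # [num_slots, 1]
        if self.training:
            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
        # Broadcast: [batch, num_slots, dim] * [num_slots, 1]
        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
        if self.scale is not None:
            x_norm = x_norm * self.scale                           # [num_slots, dim] broadcast
        if self.use_gate:
            gate = torch.sigmoid(self.gate)                        # [num_slots, 1]
            return gate * x_norm + (1 - gate) * x                  # per-slot blend
        return x_norm
```

The gate interpolates per slot between the normalized memory and the raw input, so the effective normalization strength of each slot is learned rather than fixed.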
@@ -77,7 +75,7 @@ class AdaptiveRMSMemoryNorm(nn.Module):
         # x shape: [batch_size, num_slots, dim]
         if self.training and hasattr(self, 'ema_rms'):
             # Compute current RMS across all slots and batch (scalar)
-            current_rms = x.pow(2).mean(
+            current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
             self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
             rms = self.ema_rms
         else:
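Unlike the positional variant, `AdaptiveRMSMemoryNorm` tracks a single scalar RMS over the whole batch and all slots, updated with the same EMA. A minimal sketch of just that statistic update (the variable names other than `ema_rms` and `decay` are illustrative, and the decay value is assumed):

```python
import torch

decay = 0.99                       # assumed EMA decay
ema_rms = torch.ones(())           # scalar running RMS, as kept in the ema_rms buffer

x = torch.randn(4, 8, 32)          # [batch_size, num_slots, dim]
# Mean square over dim, then averaged over batch and slots, then sqrt -> scalar RMS
current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
ema_rms = ema_rms * decay + current_rms * (1 - decay)
```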
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/RECORD
CHANGED
@@ -5,8 +5,8 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
-rxnn/memory/norm.py,sha256=
+rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
+rxnn/memory/norm.py,sha256=mu_6iZJe61ag627csfJN2JK6QmmzofjOEhxV4ZWblXs,6410
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.25.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.25.dist-info/METADATA,sha256=nuGFk4oqSMhn6vrw2KZs4RtY0_ZLowg29IlkNVHZ6Jo,25960
+rxnn-0.2.25.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.25.dist-info/RECORD,,
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/LICENSE
File without changes
{rxnn-0.2.24.dist-info → rxnn-0.2.25.dist-info}/WHEEL
File without changes