ASAC-pytorch 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,293 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import namedtuple
4
+ import torch
5
+ from torch import nn, tensor
6
+ from torch.nn import Module, Linear, ModuleList
7
+ import torch.nn.functional as F
8
+
9
+ from einops import einsum, reduce
10
+ from einops.layers.torch import Rearrange
11
+
12
+ from x_transformers import Decoder
13
+
14
+ from x_mlps_pytorch import MLP
15
+
16
+ from vector_quantize_pytorch import VectorQuantize
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from torch_einops_utils import pack_with_inverse, maybe
21
+
22
+ # helpers
23
+
24
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' still count as present)."""
    is_missing = v is None
    return not is_missing
26
+
27
def default(v, d):
    """Return `v` when it is not None, otherwise fall back to `d`."""
    if v is None:
        return d
    return v
29
+
30
+ # return types
31
+
32
# immutable result records: AttentionReturn is produced by Attention.forward,
# ASACReturn by ASAC.forward (field names given as a space-separated spec)
AttentionReturn = namedtuple('AttentionReturn', 'attended indices aux_loss aux_loss_breakdown')
ASACReturn = namedtuple('ASACReturn', 'logits aux_loss aux_loss_breakdown')
34
+
35
+ # feedforward
36
+
37
+ def FeedForward(dim, expansion_factor = 4.):
38
+ dim_inner = int(dim * expansion_factor)
39
+ return nn.Sequential(
40
+ nn.RMSNorm(dim),
41
+ nn.Linear(dim, dim_inner),
42
+ nn.GELU(),
43
+ nn.Linear(dim_inner, dim)
44
+ )
45
+
46
+ # embedding
47
+
48
def PatchEmbedding(dim, patch_size, channels = 3):
    """Cut images into non-overlapping square patches and project each patch to `dim`.

    Input (b, c, H, W) is rearranged to (b, num_patches, patch_size * patch_size * c);
    H and W must be divisible by `patch_size` for the rearrange to succeed.
    Each flattened patch is RMS-normed, linearly projected, then RMS-normed again.
    """
    flattened_patch_dim = patch_size * patch_size * channels

    to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)

    return nn.Sequential(
        to_patches,
        nn.RMSNorm(flattened_patch_dim),
        Linear(flattened_patch_dim, dim),
        nn.RMSNorm(dim),
    )
57
+
58
+ # attention
59
+
60
class Attention(Module):
    """Pre-normed multi-head self-attention with an optional attention schema on the logits.

    When `attn_schema` is given, the pre-softmax similarity logits are passed through it
    and replaced by its reconstruction (plus the original logits when `attn_add_residual`).
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        k_rmsnorm = True,
        attn_schema: Module | None = None,
        attn_add_residual = True # they had to add a residual for stability
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = nn.RMSNorm(dim)

        self.to_qkv = Linear(dim, dim_inner * 3, bias = False)
        self.combine_heads = Linear(dim_inner, dim, bias = False)

        # optional per-head rmsnorm applied to the keys only
        self.k_rmsnorm = nn.RMSNorm(dim_head) if k_rmsnorm else None

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.attn_schema = attn_schema

        # store a plain bool. `attn_add_residual and attn_schema` evaluates to the schema
        # Module itself, and nn.Module.__setattr__ would then register it a second time
        # as a submodule named `attn_add_residual`, duplicating its weights in state_dict
        self.attn_add_residual = exists(attn_schema) and bool(attn_add_residual)

        # non-persistent zero used as the placeholder aux loss when there is no schema
        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        tokens, # (b, *dims, d) - any number of middle dims, flattened into one sequence below
        pre_softmax_attn_gates = None,
        post_softmax_attn_gates = None
    ):
        """Attend over `tokens` and return an `AttentionReturn`.

        pre_softmax_attn_gates: optional additive term on the (b, h, i, j) logits, applied after
            the schema; must broadcast against them.
        post_softmax_attn_gates: optional multiplicative term on the attention weights after softmax.

        Returns AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown). The last three come
        from the schema when present; otherwise they are None / zeros.
        """
        tokens = self.norm(tokens)

        tokens, inverse_pack = pack_with_inverse(tokens, 'b * d')

        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
        q, k, v = (self.split_heads(t) for t in (q, k, v))

        # key rmsnorm when enabled (presumably `maybe` passes k through unchanged for None)
        k = maybe(self.k_rmsnorm)(k)

        q = q * self.scale

        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

        orig_sim = sim

        # the proposal - replace the logits with their attention-schema reconstruction

        aux_loss = self.zero
        aux_loss_breakdown = (self.zero, self.zero)
        indices = None

        if exists(self.attn_schema):
            sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)

            if self.attn_add_residual:
                sim = sim + orig_sim

        # modulate

        if exists(pre_softmax_attn_gates):
            sim = sim + pre_softmax_attn_gates

        # attend

        attn = sim.softmax(dim = -1)

        # modulate

        if exists(post_softmax_attn_gates):
            attn = attn * post_softmax_attn_gates

        # aggregate and combine out

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

        out = self.merge_heads(out)
        attended = self.combine_heads(out)

        # bring back the packed dimensions

        attended = inverse_pack(attended)

        return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)
148
+
149
+ # attention autoencoder
150
+
151
class AttentionSchema(Module):
    """Vector-quantized autoencoder over a layer's attention logits (the "attention schema").

    The logits are flattened per batch element ('b *'), encoded to `dim_bottleneck`, quantized
    with `VectorQuantize`, decoded back to `dim` features and unpacked to the input shape.
    `dim` must therefore equal the number of flattened logit features per batch element.
    """

    def __init__(
        self,
        dim,
        dim_bottleneck,
        kl_div_loss = True,
        detach_target = True,
        encoder: Module | None = None,
        decoder: Module | None = None,
        recon_loss_weight = 1.,
        commit_loss_weight = 1.,
        **vq_kwargs
    ):
        super().__init__()

        # default encoder / decoder are MLPs between the flattened logits and the bottleneck

        if not exists(encoder):
            encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())

        self.encoder = encoder

        # any remaining kwargs (e.g. codebook_size) configure the quantizer
        self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)

        if not exists(decoder):
            decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())

        self.decoder = decoder

        self.kl_div_loss = kl_div_loss
        self.detach_target = detach_target

        self.recon_loss_weight = recon_loss_weight
        self.commit_loss_weight = commit_loss_weight

    def forward(
        self,
        attn_sim,
        return_loss = None
    ):
        """Reconstruct `attn_sim` through the quantized bottleneck.

        return_loss: compute the reconstruction loss; defaults to `self.training`.
            When off, the reconstruction loss is a zero scalar instead of being computed.

        Returns (recon, indices, total_loss, (recon_loss, commit_loss)), where
        total_loss = recon_loss * recon_loss_weight + commit_loss * commit_loss_weight.
        """
        return_loss = default(return_loss, self.training)

        attn_features, inverse_pack = pack_with_inverse(attn_sim, 'b *')

        encoded = self.encoder(attn_features)

        quantized, indices, commit_loss = self.vq(encoded)

        decoded = self.decoder(quantized)

        recon = inverse_pack(decoded)

        # loss, mse as in paper or reverse kl

        if return_loss:
            target = attn_sim.detach() if self.detach_target else attn_sim

            if self.kl_div_loss:
                # F.kl_div(input, target) sums target * (log target - input); with the
                # reconstruction passed as `target` this is KL(recon || original), the reverse kl
                recon_loss = F.kl_div(
                    target.log_softmax(dim = -1),
                    recon.softmax(dim = -1),
                    reduction = 'batchmean'
                )
            else:
                recon_loss = F.mse_loss(recon, target)
        else:
            # no reconstruction loss requested (e.g. eval mode) - this path used to leave
            # `recon_loss` unbound and crash on the total below
            recon_loss = attn_sim.new_zeros(())

        # total

        total_loss = recon_loss * self.recon_loss_weight + commit_loss * self.commit_loss_weight

        return recon, indices, total_loss, (recon_loss, commit_loss)
221
+
222
+ # class
223
+
224
class ASAC(Module):
    """Pre-norm transformer classifier whose attention layers can each carry an attention schema.

    Pipeline: `to_embedding` -> optional learned positional embedding -> `depth` x
    (Attention + FeedForward, each residual) -> mean pool over tokens -> RMSNorm + Linear to logits.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        to_embedding,
        seq_len = None,
        dim_head = 64,
        num_classes = 10,
        use_asac = False,
        dim_bottleneck = 256,
        vq_codebook_size = 256,
        recon_loss_weight = 1.,
        commit_loss_weight = 1.
    ):
        super().__init__()

        # the schema's input width is heads * seq_len ** 2, so it cannot be built without
        # seq_len - fail loudly rather than silently training a model without ASAC
        if use_asac and not exists(seq_len):
            raise ValueError('`seq_len` must be given when `use_asac = True`, as the attention schema operates on the full (heads x seq_len x seq_len) attention map')

        self.depth = depth

        self.to_embedding = to_embedding
        self.pos_embedding = nn.Parameter(torch.randn(seq_len, dim)) if exists(seq_len) else None

        self.layers = ModuleList([])

        for _ in range(depth):
            # one schema per layer, autoencoding that layer's flattened (heads, seq_len, seq_len) logits
            attn_schema = AttentionSchema(
                dim = heads * (seq_len ** 2),
                dim_bottleneck = dim_bottleneck,
                codebook_size = vq_codebook_size,
                recon_loss_weight = recon_loss_weight,
                commit_loss_weight = commit_loss_weight
            ) if use_asac else None

            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema),
                FeedForward(dim)
            ]))

        self.to_logits = nn.Sequential(
            nn.RMSNorm(dim),
            Linear(dim, num_classes)
        )

    def forward(self, x):
        """Classify `x` and return an `ASACReturn`.

        x: raw input for `to_embedding`, which must produce (b, n, dim) tokens
           (n == seq_len when a positional embedding is used).

        Returns ASACReturn(logits, summed aux loss over layers,
        (mean recon loss per layer, mean commit loss per layer)).
        """
        x = self.to_embedding(x)

        if exists(self.pos_embedding):
            x = x + self.pos_embedding

        total_aux_loss = 0.
        total_recon_loss = 0.
        total_commit_loss = 0.

        for attn, ff in self.layers:
            attn_out, indices, aux_loss, (recon_loss, commit_loss) = attn(x)

            x = attn_out + x
            x = ff(x) + x

            total_aux_loss = total_aux_loss + aux_loss
            total_recon_loss = total_recon_loss + recon_loss
            total_commit_loss = total_commit_loss + commit_loss

        # mean pool over the token dimension
        x = reduce(x, 'b n d -> b d', 'mean')

        logits = self.to_logits(x)

        return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth))
@@ -0,0 +1 @@
1
+ from ASAC.ASAC import ASAC, PatchEmbedding
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -55,12 +55,12 @@ Implementation of [Attention Schema-based Attention Control (ASAC)](https://arxi
55
55
 
56
56
  ```bibtex
57
57
  @misc{saxena2025attentionschemabasedattentioncontrol,
58
- title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
58
+ title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
59
59
  author = {Krati Saxena and Federico Jurado Ruiz and Guido Manzi and Dianbo Liu and Alex Lamb},
60
60
  year = {2025},
61
61
  eprint = {2509.16058},
62
62
  archivePrefix = {arXiv},
63
63
  primaryClass = {cs.AI},
64
- url = {https://arxiv.org/abs/2509.16058},
64
+ url = {https://arxiv.org/abs/2509.16058},
65
65
  }
66
66
  ```
@@ -6,12 +6,12 @@ Implementation of [Attention Schema-based Attention Control (ASAC)](https://arxi
6
6
 
7
7
  ```bibtex
8
8
  @misc{saxena2025attentionschemabasedattentioncontrol,
9
- title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
9
+ title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
10
10
  author = {Krati Saxena and Federico Jurado Ruiz and Guido Manzi and Dianbo Liu and Alex Lamb},
11
11
  year = {2025},
12
12
  eprint = {2509.16058},
13
13
  archivePrefix = {arXiv},
14
14
  primaryClass = {cs.AI},
15
- url = {https://arxiv.org/abs/2509.16058},
15
+ url = {https://arxiv.org/abs/2509.16058},
16
16
  }
17
17
  ```
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.2"
3
+ version = "0.0.4"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,165 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from torch import nn, tensor
5
- from torch.nn import Module, Linear
6
- import torch.nn.functional as F
7
-
8
- from einops import einsum
9
- from einops.layers.torch import Rearrange
10
-
11
- from x_transformers import Decoder
12
-
13
- from x_mlps_pytorch import MLP
14
-
15
- from vector_quantize_pytorch import VectorQuantize
16
-
17
- from ema_pytorch import EMA
18
-
19
- from torch_einops_utils import pack_with_inverse
20
-
21
- # helpers
22
-
23
- def exists(v):
24
- return v is not None
25
-
26
- def default(v, d):
27
- return v if exists(v) else d
28
-
29
- # attention
30
-
31
- class Attention(Module):
32
- def __init__(
33
- self,
34
- dim,
35
- dim_head = 64,
36
- heads = 8,
37
- attn_schema: Module | None = None,
38
- attn_add_residual = True # they had to add a residual for stability
39
- ):
40
- super().__init__()
41
- self.scale = dim_head ** -0.5
42
- dim_inner = dim_head * heads
43
-
44
- self.to_qkv = Linear(dim, dim_inner * 3, bias = False)
45
- self.combine_heads = Linear(dim_inner, dim, bias = False)
46
-
47
- self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
48
- self.merge_heads = Rearrange('b h n d -> b n (h d)')
49
-
50
- self.attn_schema = attn_schema
51
- self.attn_add_residual = attn_add_residual and attn_schema
52
-
53
- self.register_buffer('zero', tensor(0.), persistent = False)
54
-
55
- def forward(
56
- self,
57
- tokens, # (b h w d)
58
- ):
59
- tokens, inverse_pack = pack_with_inverse(tokens, 'b * d')
60
-
61
- q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
62
- q, k, v = (self.split_heads(t) for t in (q, k, v))
63
-
64
- q = q * self.scale
65
-
66
- sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
67
-
68
- orig_sim = sim
69
-
70
- # the proposal
71
-
72
- aux_loss = self.zero
73
-
74
- if exists(self.attn_schema):
75
- sim, indices, aux_loss = self.attn_schema(orig_sim)
76
-
77
- if self.attn_add_residual:
78
- sim = sim + orig_sim
79
-
80
- # attend
81
-
82
- attn = sim.softmax(dim = -1)
83
-
84
- # aggregate and combine out
85
-
86
- out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
87
-
88
- out = self.merge_heads(out)
89
- attended = self.combine_heads(out)
90
-
91
- # bring back the packed dimensions
92
-
93
- attended = inverse_pack(attended)
94
-
95
- return attended, indices, aux_loss
96
-
97
- # attention autoencoder
98
-
99
- class AttentionSchema(Module):
100
- def __init__(
101
- self,
102
- dim,
103
- dim_bottleneck,
104
- kl_div_loss = True,
105
- detach_target = True,
106
- **vq_kwargs
107
- ):
108
- super().__init__()
109
- self.encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())
110
-
111
- self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)
112
-
113
- self.decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())
114
-
115
- self.kl_div_loss = kl_div_loss
116
- self.detach_target = detach_target
117
-
118
- def forward(
119
- self,
120
- attn_sim,
121
- return_loss = None
122
- ):
123
- return_loss = default(return_loss, self.training)
124
-
125
- attn_features, inverse_pack = pack_with_inverse(attn_sim, 'b *')
126
-
127
- encoded = self.encoder(attn_features)
128
-
129
- quantized, indices, commit_loss = self.vq(encoded)
130
-
131
- decoded = self.decoder(quantized)
132
-
133
- recon = inverse_pack(decoded)
134
-
135
- # loss, mse as in paper or reverse kl
136
-
137
- if return_loss:
138
- if self.detach_target:
139
- attn_sim = attn_sim.detach()
140
-
141
- if self.kl_div_loss:
142
- recon_loss = F.kl_div(
143
- attn_sim.log_softmax(dim = -1),
144
- recon,
145
- reduction = 'batchmean'
146
- )
147
- else:
148
- recon_loss = F.mse_loss(recon, attn_sim)
149
-
150
- # total
151
-
152
- total_loss = recon_loss + commit_loss
153
-
154
- loss_breakdown = (recon_loss, commit_loss)
155
-
156
- return recon, indices, total_loss
157
-
158
- # class
159
-
160
- class ASAC(Module):
161
- def __init__(self):
162
- super().__init__()
163
-
164
- def forward(self, x):
165
- return x
@@ -1 +0,0 @@
1
- from ASAC.ASAC import ASAC
File without changes
File without changes