ASAC-pytorch 0.0.1__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asac_pytorch-0.0.4/ASAC/ASAC.py +293 -0
- asac_pytorch-0.0.4/ASAC/__init__.py +1 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.4}/PKG-INFO +7 -3
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.4}/README.md +2 -2
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.4}/pyproject.toml +5 -1
- asac_pytorch-0.0.1/ASAC/ASAC.py +0 -0
- asac_pytorch-0.0.1/ASAC/__init__.py +0 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.4}/.gitignore +0 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.4}/LICENSE +0 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import namedtuple
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn, tensor
|
|
6
|
+
from torch.nn import Module, Linear, ModuleList
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from einops import einsum, reduce
|
|
10
|
+
from einops.layers.torch import Rearrange
|
|
11
|
+
|
|
12
|
+
from x_transformers import Decoder
|
|
13
|
+
|
|
14
|
+
from x_mlps_pytorch import MLP
|
|
15
|
+
|
|
16
|
+
from vector_quantize_pytorch import VectorQuantize
|
|
17
|
+
|
|
18
|
+
from ema_pytorch import EMA
|
|
19
|
+
|
|
20
|
+
from torch_einops_utils import pack_with_inverse, maybe
|
|
21
|
+
|
|
22
|
+
# helpers
|
|
23
|
+
|
|
24
|
+
def exists(v):
    """Return True when `v` is anything other than None."""
    is_missing = v is None
    return not is_missing
|
|
26
|
+
|
|
27
|
+
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if exists(v):
        return v

    return d
|
|
29
|
+
|
|
30
|
+
# return types
|
|
31
|
+
|
|
32
|
+
# what `Attention.forward` hands back
AttentionReturn = namedtuple(
    'AttentionReturn',
    ('attended', 'indices', 'aux_loss', 'aux_loss_breakdown')
)

# what `ASAC.forward` hands back
ASACReturn = namedtuple(
    'ASACReturn',
    ('logits', 'aux_loss', 'aux_loss_breakdown')
)
|
|
34
|
+
|
|
35
|
+
# feedforward
|
|
36
|
+
|
|
37
|
+
def FeedForward(dim, expansion_factor = 4.):
|
|
38
|
+
dim_inner = int(dim * expansion_factor)
|
|
39
|
+
return nn.Sequential(
|
|
40
|
+
nn.RMSNorm(dim),
|
|
41
|
+
nn.Linear(dim, dim_inner),
|
|
42
|
+
nn.GELU(),
|
|
43
|
+
nn.Linear(dim_inner, dim)
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# embedding
|
|
47
|
+
|
|
48
|
+
def PatchEmbedding(dim, patch_size, channels = 3):
    """Build a module that turns a `(b, c, H, W)` image batch into patch tokens of width `dim`.

    The image is cut into non-overlapping `patch_size x patch_size` patches,
    each patch is flattened, normalized, linearly projected to `dim`, and
    normalized again.
    """
    flattened_patch_dim = channels * (patch_size ** 2)

    to_patches = Rearrange(
        'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
        p1 = patch_size,
        p2 = patch_size
    )

    return nn.Sequential(
        to_patches,
        nn.RMSNorm(flattened_patch_dim),
        Linear(flattened_patch_dim, dim),
        nn.RMSNorm(dim),
    )
|
|
57
|
+
|
|
58
|
+
# attention
|
|
59
|
+
|
|
60
|
+
class Attention(Module):
    """Multi-head self-attention whose pre-softmax similarities can be routed through an attention schema.

    When `attn_schema` is given, the raw query/key similarity matrix is handed
    to it and replaced by what it returns (optionally with the raw similarities
    added back as a residual) before the softmax is taken.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        k_rmsnorm = True,
        attn_schema: Module | None = None,
        attn_add_residual = True # they had to add a residual for stability
    ):
        """
        Args:
            dim: feature dimension of the incoming tokens.
            dim_head: dimension of each attention head.
            heads: number of attention heads.
            k_rmsnorm: whether to apply an RMSNorm to the keys.
            attn_schema: optional module called on the similarity matrix; it must
                return `(sim, indices, aux_loss, aux_loss_breakdown)`.
            attn_add_residual: whether to add the raw similarities back onto the
                schema output. Only has an effect when `attn_schema` is given.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = nn.RMSNorm(dim)

        self.to_qkv = Linear(dim, dim_inner * 3, bias = False)
        self.combine_heads = Linear(dim_inner, dim, bias = False)

        self.k_rmsnorm = nn.RMSNorm(dim_head) if k_rmsnorm else None

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.attn_schema = attn_schema

        # keep this a plain bool - assigning the schema module itself here would
        # register it a second time as a submodule under this attribute name,
        # duplicating its weights in the state dict

        self.attn_add_residual = attn_add_residual and exists(attn_schema)

        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        tokens, # (b h w d)
        pre_softmax_attn_gates = None,
        post_softmax_attn_gates = None
    ):
        """Attend over `tokens` and return an `AttentionReturn`.

        Args:
            tokens: input features; every axis between the first and the last
                is packed into a single sequence axis and restored on output.
            pre_softmax_attn_gates: optional tensor added to the similarities
                before the softmax.
            post_softmax_attn_gates: optional tensor multiplied into the
                attention weights after the softmax.

        Returns:
            `AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)`.
            Without an attention schema, `indices` is None and the loss terms
            are the zero buffer.
        """
        tokens = self.norm(tokens)

        tokens, inverse_pack = pack_with_inverse(tokens, 'b * d')

        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
        q, k, v = (self.split_heads(t) for t in (q, k, v))

        # presumably a no-op when the key norm was disabled (set to None) - confirm against `maybe`

        k = maybe(self.k_rmsnorm)(k)

        q = q * self.scale

        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

        orig_sim = sim

        # the proposal

        aux_loss = self.zero
        aux_loss_breakdown = (self.zero, self.zero)
        indices = None

        if exists(self.attn_schema):
            sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)

            if self.attn_add_residual:
                sim = sim + orig_sim

        # modulate

        if exists(pre_softmax_attn_gates):
            sim = sim + pre_softmax_attn_gates

        # attend

        attn = sim.softmax(dim = -1)

        # modulate

        if exists(post_softmax_attn_gates):
            attn = attn * post_softmax_attn_gates

        # aggregate and combine out

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

        out = self.merge_heads(out)
        attended = self.combine_heads(out)

        # bring back the packed dimensions

        attended = inverse_pack(attended)

        return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)
|
|
148
|
+
|
|
149
|
+
# attention autoencoder
|
|
150
|
+
|
|
151
|
+
class AttentionSchema(Module):
    """Vector-quantized autoencoder over a flattened attention similarity matrix.

    The similarities are flattened per batch element, encoded to a bottleneck,
    vector quantized, decoded, and reshaped back. When a loss is requested, a
    reconstruction loss against the input similarities is combined with the
    quantizer's commitment loss.
    """

    def __init__(
        self,
        dim,
        dim_bottleneck,
        kl_div_loss = True,
        detach_target = True,
        encoder: Module | None = None,
        decoder: Module | None = None,
        recon_loss_weight = 1.,
        commit_loss_weight = 1.,
        **vq_kwargs
    ):
        """
        Args:
            dim: size of the flattened similarity vector fed to the encoder.
            dim_bottleneck: size of the quantized bottleneck.
            kl_div_loss: use the KL divergence reconstruction loss, otherwise MSE.
            detach_target: detach the input similarities before they are used
                as the reconstruction target.
            encoder: optional custom encoder, defaults to an MLP.
            decoder: optional custom decoder, defaults to an MLP.
            recon_loss_weight: weight of the reconstruction loss in the total.
            commit_loss_weight: weight of the commitment loss in the total.
            **vq_kwargs: forwarded to `VectorQuantize`.
        """
        super().__init__()

        if not exists(encoder):
            encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())

        self.encoder = encoder

        self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)

        if not exists(decoder):
            decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())

        self.decoder = decoder

        self.kl_div_loss = kl_div_loss
        self.detach_target = detach_target

        self.recon_loss_weight = recon_loss_weight
        self.commit_loss_weight = commit_loss_weight

        # zero placeholder for the loss terms when no loss is requested

        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        attn_sim,
        return_loss = None
    ):
        """Reconstruct `attn_sim` through the quantized bottleneck.

        Args:
            attn_sim: attention similarities; everything past the first axis is
                flattened for the encoder and restored on the reconstruction.
            return_loss: whether to compute the losses. Defaults to
                `self.training` when None.

        Returns:
            `(recon, indices, total_loss, (recon_loss, commit_loss))`. All loss
            terms are the zero buffer when no loss is requested.
        """
        return_loss = default(return_loss, self.training)

        attn_features, inverse_pack = pack_with_inverse(attn_sim, 'b *')

        encoded = self.encoder(attn_features)

        quantized, indices, commit_loss = self.vq(encoded)

        decoded = self.decoder(quantized)

        recon = inverse_pack(decoded)

        # no loss requested (e.g. eval mode) - the loss terms below would
        # otherwise never be assigned before being returned

        if not return_loss:
            return recon, indices, self.zero, (self.zero, self.zero)

        # loss, mse as in paper or reverse kl

        if self.detach_target:
            attn_sim = attn_sim.detach()

        if self.kl_div_loss:
            # F.kl_div takes log-probabilities as input and probabilities as target,
            # so the reconstruction is the distribution the expectation is taken under

            recon_loss = F.kl_div(
                attn_sim.log_softmax(dim = -1),
                recon.softmax(dim = -1),
                reduction = 'batchmean'
            )
        else:
            recon_loss = F.mse_loss(recon, attn_sim)

        # total

        total_loss = recon_loss * self.recon_loss_weight + commit_loss * self.commit_loss_weight

        return recon, indices, total_loss, (recon_loss, commit_loss)
|
|
221
|
+
|
|
222
|
+
# class
|
|
223
|
+
|
|
224
|
+
class ASAC(Module):
    """Transformer classifier whose attention layers can each carry an attention schema.

    Inputs are embedded by the supplied `to_embedding`, passed through `depth`
    attention + feedforward layers with residuals, mean pooled over the
    sequence axis, and projected to `num_classes` logits.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        to_embedding,
        seq_len = None,
        dim_head = 64,
        num_classes = 10,
        use_asac = False,
        dim_bottleneck = 256,
        vq_codebook_size = 256,
        recon_loss_weight = 1.,
        commit_loss_weight = 1.
    ):
        """
        Args:
            dim: model feature dimension.
            depth: number of attention + feedforward layers.
            heads: number of attention heads.
            to_embedding: callable mapping the raw input to `(b, n, dim)` tokens.
            seq_len: sequence length; enables the learned positional embedding
                and is required for the attention schema to be created.
            dim_head: dimension of each attention head.
            num_classes: number of output logits.
            use_asac: attach an attention schema to every attention layer
                (only takes effect when `seq_len` is given).
            dim_bottleneck: bottleneck size of each attention schema.
            vq_codebook_size: codebook size passed to each schema's quantizer.
            recon_loss_weight: reconstruction loss weight of each schema.
            commit_loss_weight: commitment loss weight of each schema.
        """
        super().__init__()

        self.depth = depth

        self.to_embedding = to_embedding

        if exists(seq_len):
            pos_embedding = nn.Parameter(torch.randn(seq_len, dim))
        else:
            pos_embedding = None

        self.pos_embedding = pos_embedding

        # the schema autoencodes the flattened (heads, seq_len, seq_len) similarities,
        # so it can only be sized when the sequence length is known up front

        with_schema = use_asac and exists(seq_len)

        blocks = []

        for _ in range(depth):
            schema = None

            if with_schema:
                schema = AttentionSchema(
                    dim = heads * (seq_len ** 2),
                    dim_bottleneck = dim_bottleneck,
                    codebook_size = vq_codebook_size,
                    recon_loss_weight = recon_loss_weight,
                    commit_loss_weight = commit_loss_weight
                )

            attention = Attention(dim, dim_head = dim_head, heads = heads, attn_schema = schema)
            feedforward = FeedForward(dim)

            blocks.append(ModuleList([attention, feedforward]))

        self.layers = ModuleList(blocks)

        final_norm = nn.RMSNorm(dim)
        classifier = Linear(dim, num_classes)

        self.to_logits = nn.Sequential(final_norm, classifier)

    def forward(self, x):
        """Classify `x` and return `ASACReturn(logits, aux_loss, aux_loss_breakdown)`.

        `aux_loss` is the sum of the per-layer auxiliary losses, while the
        breakdown holds the reconstruction and commitment losses averaged over
        the layers.
        """
        hidden = self.to_embedding(x)

        if exists(self.pos_embedding):
            hidden = hidden + self.pos_embedding

        aux_sum = 0.
        recon_sum = 0.
        commit_sum = 0.

        for attn, ff in self.layers:
            attended, _, layer_aux, (layer_recon, layer_commit) = attn(hidden)

            # residual connections around attention and feedforward

            hidden = attended + hidden
            hidden = ff(hidden) + hidden

            aux_sum = aux_sum + layer_aux
            recon_sum = recon_sum + layer_recon
            commit_sum = commit_sum + layer_commit

        # mean pool over the sequence axis

        pooled = reduce(hidden, 'b n d -> b d', 'mean')

        logits = self.to_logits(pooled)

        breakdown = (recon_sum / self.depth, commit_sum / self.depth)

        return ASACReturn(logits, aux_sum, breakdown)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ASAC.ASAC import ASAC, PatchEmbedding
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.4
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
@@ -36,8 +36,12 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
36
36
|
Requires-Python: >=3.10
|
|
37
37
|
Requires-Dist: einops>=0.8.1
|
|
38
38
|
Requires-Dist: einx>=0.3.0
|
|
39
|
+
Requires-Dist: ema-pytorch
|
|
39
40
|
Requires-Dist: torch-einops-utils>=0.1.2
|
|
40
41
|
Requires-Dist: torch>=2.5
|
|
42
|
+
Requires-Dist: vector-quantize-pytorch
|
|
43
|
+
Requires-Dist: x-mlps-pytorch
|
|
44
|
+
Requires-Dist: x-transformers
|
|
41
45
|
Provides-Extra: examples
|
|
42
46
|
Provides-Extra: test
|
|
43
47
|
Requires-Dist: pytest; extra == 'test'
|
|
@@ -51,12 +55,12 @@ Implementation of [Attention Schema-based Attention Control (ASAC)](https://arxi
|
|
|
51
55
|
|
|
52
56
|
```bibtex
|
|
53
57
|
@misc{saxena2025attentionschemabasedattentioncontrol,
|
|
54
|
-
title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
|
|
58
|
+
title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
|
|
55
59
|
author = {Krati Saxena and Federico Jurado Ruiz and Guido Manzi and Dianbo Liu and Alex Lamb},
|
|
56
60
|
year = {2025},
|
|
57
61
|
eprint = {2509.16058},
|
|
58
62
|
archivePrefix = {arXiv},
|
|
59
63
|
primaryClass = {cs.AI},
|
|
60
|
-
url = {https://arxiv.org/abs/2509.16058},
|
|
64
|
+
url = {https://arxiv.org/abs/2509.16058},
|
|
61
65
|
}
|
|
62
66
|
```
|
|
@@ -6,12 +6,12 @@ Implementation of [Attention Schema-based Attention Control (ASAC)](https://arxi
|
|
|
6
6
|
|
|
7
7
|
```bibtex
|
|
8
8
|
@misc{saxena2025attentionschemabasedattentioncontrol,
|
|
9
|
-
title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
|
|
9
|
+
title = {Attention Schema-based Attention Control (ASAC): A Cognitive-Inspired Approach for Attention Management in Transformers},
|
|
10
10
|
author = {Krati Saxena and Federico Jurado Ruiz and Guido Manzi and Dianbo Liu and Alex Lamb},
|
|
11
11
|
year = {2025},
|
|
12
12
|
eprint = {2509.16058},
|
|
13
13
|
archivePrefix = {arXiv},
|
|
14
14
|
primaryClass = {cs.AI},
|
|
15
|
-
url = {https://arxiv.org/abs/2509.16058},
|
|
15
|
+
url = {https://arxiv.org/abs/2509.16058},
|
|
16
16
|
}
|
|
17
17
|
```
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "ASAC-pytorch"
|
|
3
|
-
version = "0.0.1"
|
|
3
|
+
version = "0.0.4"
|
|
4
4
|
description = "Implementation of Attention Schema-based Attention Control (ASAC)"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -26,8 +26,12 @@ classifiers=[
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"einx>=0.3.0",
|
|
28
28
|
"einops>=0.8.1",
|
|
29
|
+
"ema-pytorch",
|
|
29
30
|
"torch>=2.5",
|
|
30
31
|
"torch-einops-utils>=0.1.2",
|
|
32
|
+
"vector-quantize-pytorch",
|
|
33
|
+
"x-transformers",
|
|
34
|
+
"x-mlps-pytorch"
|
|
31
35
|
]
|
|
32
36
|
|
|
33
37
|
[project.urls]
|
asac_pytorch-0.0.1/ASAC/ASAC.py
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|