ASAC-pytorch 0.0.7__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/ASAC/ASAC.py +119 -19
- asac_pytorch-0.0.10/ASAC/__init__.py +1 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/PKG-INFO +1 -1
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/pyproject.toml +1 -1
- asac_pytorch-0.0.7/ASAC/__init__.py +0 -1
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/.gitignore +0 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/LICENSE +0 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.10}/README.md +0 -0
|
@@ -6,10 +6,10 @@ from torch import nn, tensor
|
|
|
6
6
|
from torch.nn import Module, Linear, ModuleList
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
|
|
9
|
-
from einops import einsum, reduce
|
|
9
|
+
from einops import einsum, reduce, rearrange
|
|
10
10
|
from einops.layers.torch import Rearrange
|
|
11
11
|
|
|
12
|
-
from x_transformers import Decoder
|
|
12
|
+
from x_transformers import Decoder, AutoregressiveWrapper, TransformerWrapper
|
|
13
13
|
|
|
14
14
|
from x_mlps_pytorch import MLP
|
|
15
15
|
|
|
@@ -29,8 +29,8 @@ def default(v, d):
|
|
|
29
29
|
|
|
30
30
|
# return types
|
|
31
31
|
|
|
32
|
-
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown'])
|
|
33
|
-
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown'])
|
|
32
|
+
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'attn_sim'])
|
|
33
|
+
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'attn_sims', 'attn_schema_indices', 'attn_schema_autoregressive_loss'])
|
|
34
34
|
|
|
35
35
|
# feedforward
|
|
36
36
|
|
|
@@ -65,7 +65,8 @@ class Attention(Module):
|
|
|
65
65
|
heads = 8,
|
|
66
66
|
k_rmsnorm = True,
|
|
67
67
|
attn_schema: Module | None = None,
|
|
68
|
-
attn_add_residual = True # they had to add a residual for stability
|
|
68
|
+
attn_add_residual = True, # they had to add a residual for stability
|
|
69
|
+
stochastic_sample_attn = False
|
|
69
70
|
):
|
|
70
71
|
super().__init__()
|
|
71
72
|
self.scale = dim_head ** -0.5
|
|
@@ -84,13 +85,16 @@ class Attention(Module):
|
|
|
84
85
|
self.attn_schema = attn_schema
|
|
85
86
|
self.attn_add_residual = attn_add_residual and attn_schema
|
|
86
87
|
|
|
88
|
+
self.stochastic_sample_attn = stochastic_sample_attn
|
|
89
|
+
|
|
87
90
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
88
91
|
|
|
89
92
|
def forward(
|
|
90
93
|
self,
|
|
91
94
|
tokens, # (b h w d)
|
|
92
95
|
pre_softmax_attn_gates = None,
|
|
93
|
-
post_softmax_attn_gates = None
|
|
96
|
+
post_softmax_attn_gates = None,
|
|
97
|
+
attn_schema_target = None
|
|
94
98
|
):
|
|
95
99
|
tokens = self.norm(tokens)
|
|
96
100
|
|
|
@@ -114,7 +118,7 @@ class Attention(Module):
|
|
|
114
118
|
indices = None
|
|
115
119
|
|
|
116
120
|
if exists(self.attn_schema):
|
|
117
|
-
sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
|
|
121
|
+
sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim, target_sim = attn_schema_target)
|
|
118
122
|
|
|
119
123
|
if self.attn_add_residual:
|
|
120
124
|
sim = (sim + orig_sim) * 0.5
|
|
@@ -126,7 +130,10 @@ class Attention(Module):
|
|
|
126
130
|
|
|
127
131
|
# attend
|
|
128
132
|
|
|
129
|
-
|
|
133
|
+
if self.stochastic_sample_attn:
|
|
134
|
+
attn = F.gumbel_softmax(sim, tau = 1., hard = True, dim = -1)
|
|
135
|
+
else:
|
|
136
|
+
attn = sim.softmax(dim = -1)
|
|
130
137
|
|
|
131
138
|
# modulate
|
|
132
139
|
|
|
@@ -144,7 +151,13 @@ class Attention(Module):
|
|
|
144
151
|
|
|
145
152
|
attended = inverse_pack(attended)
|
|
146
153
|
|
|
147
|
-
return AttentionReturn(
|
|
154
|
+
return AttentionReturn(
|
|
155
|
+
attended,
|
|
156
|
+
indices,
|
|
157
|
+
aux_loss,
|
|
158
|
+
aux_loss_breakdown,
|
|
159
|
+
orig_sim
|
|
160
|
+
)
|
|
148
161
|
|
|
149
162
|
# attention autoencoder
|
|
150
163
|
|
|
@@ -186,7 +199,8 @@ class AttentionSchema(Module):
|
|
|
186
199
|
def forward(
|
|
187
200
|
self,
|
|
188
201
|
attn_sim,
|
|
189
|
-
return_loss = None
|
|
202
|
+
return_loss = None,
|
|
203
|
+
target_sim = None
|
|
190
204
|
):
|
|
191
205
|
return_loss = default(return_loss, self.training)
|
|
192
206
|
|
|
@@ -204,18 +218,20 @@ class AttentionSchema(Module):
|
|
|
204
218
|
|
|
205
219
|
recon_loss = self.zero
|
|
206
220
|
|
|
221
|
+
target = default(target_sim, attn_sim)
|
|
222
|
+
|
|
207
223
|
if return_loss:
|
|
208
224
|
if self.detach_target:
|
|
209
|
-
|
|
225
|
+
target = target.detach()
|
|
210
226
|
|
|
211
227
|
if self.kl_div_loss:
|
|
212
228
|
recon_loss = F.kl_div(
|
|
213
|
-
|
|
229
|
+
target.log_softmax(dim = -1),
|
|
214
230
|
recon.softmax(dim = -1),
|
|
215
231
|
reduction = 'none'
|
|
216
232
|
).sum(dim = -1).mean()
|
|
217
233
|
else:
|
|
218
|
-
recon_loss = F.mse_loss(recon,
|
|
234
|
+
recon_loss = F.mse_loss(recon, target)
|
|
219
235
|
|
|
220
236
|
# total
|
|
221
237
|
|
|
@@ -241,10 +257,15 @@ class ASAC(Module):
|
|
|
241
257
|
vq_codebook_size = 256,
|
|
242
258
|
recon_loss_weight = 1.,
|
|
243
259
|
commit_loss_weight = 1.,
|
|
244
|
-
kl_div_loss = True
|
|
260
|
+
kl_div_loss = True,
|
|
261
|
+
stochastic_sample_attn = False,
|
|
262
|
+
awareness_model_depth = 2,
|
|
263
|
+
**awareness_model_kwargs
|
|
245
264
|
):
|
|
246
265
|
super().__init__()
|
|
247
266
|
|
|
267
|
+
assert depth >= 2, 'depth must be at least 2'
|
|
268
|
+
|
|
248
269
|
self.depth = depth
|
|
249
270
|
|
|
250
271
|
self.to_embedding = to_embedding
|
|
@@ -263,16 +284,40 @@ class ASAC(Module):
|
|
|
263
284
|
) if use_asac and exists(seq_len) else None
|
|
264
285
|
|
|
265
286
|
self.layers.append(ModuleList([
|
|
266
|
-
Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema),
|
|
287
|
+
Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema, stochastic_sample_attn = stochastic_sample_attn),
|
|
267
288
|
FeedForward(dim)
|
|
268
289
|
]))
|
|
269
290
|
|
|
291
|
+
# autoregressive awareness model (attention schema theory)
|
|
292
|
+
|
|
293
|
+
self.awareness_transformer = None
|
|
294
|
+
self.awareness_model = None
|
|
295
|
+
|
|
296
|
+
if use_asac and exists(seq_len):
|
|
297
|
+
self.awareness_transformer = TransformerWrapper(
|
|
298
|
+
num_tokens = vq_codebook_size,
|
|
299
|
+
max_seq_len = depth,
|
|
300
|
+
attn_layers = Decoder(
|
|
301
|
+
dim = dim,
|
|
302
|
+
depth = awareness_model_depth,
|
|
303
|
+
heads = heads,
|
|
304
|
+
rotary_pos_emb = True,
|
|
305
|
+
**awareness_model_kwargs
|
|
306
|
+
)
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
self.awareness_model = AutoregressiveWrapper(self.awareness_transformer)
|
|
310
|
+
|
|
311
|
+
# zero buffer for auxiliary losses
|
|
312
|
+
|
|
313
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
314
|
+
|
|
270
315
|
self.to_logits = nn.Sequential(
|
|
271
316
|
nn.RMSNorm(dim),
|
|
272
317
|
Linear(dim, num_classes)
|
|
273
318
|
)
|
|
274
319
|
|
|
275
|
-
def forward(self, x):
|
|
320
|
+
def forward(self, x, attn_schema_targets = None):
|
|
276
321
|
x = self.to_embedding(x)
|
|
277
322
|
|
|
278
323
|
if exists(self.pos_embedding):
|
|
@@ -280,8 +325,16 @@ class ASAC(Module):
|
|
|
280
325
|
|
|
281
326
|
total_aux_loss = total_recon_loss = total_commit_loss = 0.
|
|
282
327
|
|
|
283
|
-
|
|
284
|
-
|
|
328
|
+
attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
|
|
329
|
+
attn_sims = []
|
|
330
|
+
attn_schema_indices = []
|
|
331
|
+
|
|
332
|
+
for (attn, ff), target in zip(self.layers, attn_schema_targets):
|
|
333
|
+
attn_out, indices, aux_loss, (recon_loss, commit_loss), attn_sim = attn(x, attn_schema_target = target)
|
|
334
|
+
|
|
335
|
+
attn_sims.append(attn_sim)
|
|
336
|
+
if exists(indices):
|
|
337
|
+
attn_schema_indices.append(indices)
|
|
285
338
|
|
|
286
339
|
x = attn_out + x
|
|
287
340
|
x = ff(x) + x
|
|
@@ -294,4 +347,51 @@ class ASAC(Module):
|
|
|
294
347
|
|
|
295
348
|
logits = self.to_logits(x)
|
|
296
349
|
|
|
297
|
-
|
|
350
|
+
attn_schema_autoregressive_loss = self.zero
|
|
351
|
+
|
|
352
|
+
if attn_schema_indices:
|
|
353
|
+
attn_schema_indices = rearrange(attn_schema_indices, 'depth b ... -> b (depth ...)')
|
|
354
|
+
|
|
355
|
+
if exists(self.awareness_model):
|
|
356
|
+
attn_schema_autoregressive_loss = self.awareness_model(attn_schema_indices)
|
|
357
|
+
else:
|
|
358
|
+
attn_schema_indices = None
|
|
359
|
+
|
|
360
|
+
return ASACReturn(
|
|
361
|
+
logits,
|
|
362
|
+
total_aux_loss,
|
|
363
|
+
(total_recon_loss / self.depth, total_commit_loss / self.depth),
|
|
364
|
+
attn_sims,
|
|
365
|
+
attn_schema_indices,
|
|
366
|
+
attn_schema_autoregressive_loss
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
class EMA_ASAC(Module):
|
|
370
|
+
def __init__(
|
|
371
|
+
self,
|
|
372
|
+
asac_model,
|
|
373
|
+
ema_decay = 0.999,
|
|
374
|
+
**ema_kwargs
|
|
375
|
+
):
|
|
376
|
+
super().__init__()
|
|
377
|
+
self.asac = asac_model
|
|
378
|
+
self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)
|
|
379
|
+
|
|
380
|
+
def update(self):
|
|
381
|
+
self.ema_model.update()
|
|
382
|
+
|
|
383
|
+
def forward(self, x, use_ema = False):
|
|
384
|
+
if use_ema:
|
|
385
|
+
return self.ema_model(x)
|
|
386
|
+
|
|
387
|
+
if not self.training:
|
|
388
|
+
return self.asac(x)
|
|
389
|
+
|
|
390
|
+
# get EMA targets
|
|
391
|
+
with torch.no_grad():
|
|
392
|
+
self.ema_model.eval()
|
|
393
|
+
ema_outputs = self.ema_model(x)
|
|
394
|
+
ema_targets = [sim.detach() for sim in ema_outputs.attn_sims]
|
|
395
|
+
self.ema_model.train()
|
|
396
|
+
|
|
397
|
+
return self.asac(x, attn_schema_targets = ema_targets)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ASAC.ASAC import ASAC, PatchEmbedding, EMA_ASAC
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from ASAC.ASAC import ASAC, PatchEmbedding
|
|
File without changes
|
|
File without changes
|
|
File without changes
|