ASAC-pytorch 0.0.7__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/ASAC/ASAC.py +47 -13
- asac_pytorch-0.0.9/ASAC/__init__.py +1 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/PKG-INFO +1 -1
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/pyproject.toml +1 -1
- asac_pytorch-0.0.7/ASAC/__init__.py +0 -1
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/.gitignore +0 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/LICENSE +0 -0
- {asac_pytorch-0.0.7 → asac_pytorch-0.0.9}/README.md +0 -0
|
@@ -29,8 +29,8 @@ def default(v, d):
|
|
|
29
29
|
|
|
30
30
|
# return types
|
|
31
31
|
|
|
32
|
-
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown'])
|
|
33
|
-
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown'])
|
|
32
|
+
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'dot_sim'])
|
|
33
|
+
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'dot_sims'])
|
|
34
34
|
|
|
35
35
|
# feedforward
|
|
36
36
|
|
|
@@ -90,7 +90,8 @@ class Attention(Module):
|
|
|
90
90
|
self,
|
|
91
91
|
tokens, # (b h w d)
|
|
92
92
|
pre_softmax_attn_gates = None,
|
|
93
|
-
post_softmax_attn_gates = None
|
|
93
|
+
post_softmax_attn_gates = None,
|
|
94
|
+
attn_schema_target = None
|
|
94
95
|
):
|
|
95
96
|
tokens = self.norm(tokens)
|
|
96
97
|
|
|
@@ -114,7 +115,7 @@ class Attention(Module):
|
|
|
114
115
|
indices = None
|
|
115
116
|
|
|
116
117
|
if exists(self.attn_schema):
|
|
117
|
-
sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
|
|
118
|
+
sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim, target_sim = attn_schema_target)
|
|
118
119
|
|
|
119
120
|
if self.attn_add_residual:
|
|
120
121
|
sim = (sim + orig_sim) * 0.5
|
|
@@ -144,7 +145,7 @@ class Attention(Module):
|
|
|
144
145
|
|
|
145
146
|
attended = inverse_pack(attended)
|
|
146
147
|
|
|
147
|
-
return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)
|
|
148
|
+
return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown, orig_sim)
|
|
148
149
|
|
|
149
150
|
# attention autoencoder
|
|
150
151
|
|
|
@@ -186,7 +187,8 @@ class AttentionSchema(Module):
|
|
|
186
187
|
def forward(
|
|
187
188
|
self,
|
|
188
189
|
attn_sim,
|
|
189
|
-
return_loss = None
|
|
190
|
+
return_loss = None,
|
|
191
|
+
target_sim = None
|
|
190
192
|
):
|
|
191
193
|
return_loss = default(return_loss, self.training)
|
|
192
194
|
|
|
@@ -204,18 +206,20 @@ class AttentionSchema(Module):
|
|
|
204
206
|
|
|
205
207
|
recon_loss = self.zero
|
|
206
208
|
|
|
209
|
+
target = default(target_sim, attn_sim)
|
|
210
|
+
|
|
207
211
|
if return_loss:
|
|
208
212
|
if self.detach_target:
|
|
209
|
-
|
|
213
|
+
target = target.detach()
|
|
210
214
|
|
|
211
215
|
if self.kl_div_loss:
|
|
212
216
|
recon_loss = F.kl_div(
|
|
213
|
-
|
|
217
|
+
target.log_softmax(dim = -1),
|
|
214
218
|
recon.softmax(dim = -1),
|
|
215
219
|
reduction = 'none'
|
|
216
220
|
).sum(dim = -1).mean()
|
|
217
221
|
else:
|
|
218
|
-
recon_loss = F.mse_loss(recon, attn_sim)
|
|
222
|
+
recon_loss = F.mse_loss(recon, target)
|
|
219
223
|
|
|
220
224
|
# total
|
|
221
225
|
|
|
@@ -272,7 +276,7 @@ class ASAC(Module):
|
|
|
272
276
|
Linear(dim, num_classes)
|
|
273
277
|
)
|
|
274
278
|
|
|
275
|
-
def forward(self, x):
|
|
279
|
+
def forward(self, x, attn_schema_targets = None):
|
|
276
280
|
x = self.to_embedding(x)
|
|
277
281
|
|
|
278
282
|
if exists(self.pos_embedding):
|
|
@@ -280,8 +284,13 @@ class ASAC(Module):
|
|
|
280
284
|
|
|
281
285
|
total_aux_loss = total_recon_loss = total_commit_loss = 0.
|
|
282
286
|
|
|
283
|
-
|
|
284
|
-
|
|
287
|
+
attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
|
|
288
|
+
dot_sims = []
|
|
289
|
+
|
|
290
|
+
for (attn, ff), target in zip(self.layers, attn_schema_targets):
|
|
291
|
+
attn_out, indices, aux_loss, (recon_loss, commit_loss), dot_sim = attn(x, attn_schema_target = target)
|
|
292
|
+
|
|
293
|
+
dot_sims.append(dot_sim)
|
|
285
294
|
|
|
286
295
|
x = attn_out + x
|
|
287
296
|
x = ff(x) + x
|
|
@@ -294,4 +303,29 @@ class ASAC(Module):
|
|
|
294
303
|
|
|
295
304
|
logits = self.to_logits(x)
|
|
296
305
|
|
|
297
|
-
return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth))
|
|
306
|
+
return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth), dot_sims)
|
|
307
|
+
|
|
308
|
+
class EMA_ASAC(Module):
|
|
309
|
+
def __init__(self, asac_model, ema_decay = 0.999, **ema_kwargs):
|
|
310
|
+
super().__init__()
|
|
311
|
+
self.asac = asac_model
|
|
312
|
+
self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)
|
|
313
|
+
|
|
314
|
+
def forward(self, x, use_ema = False):
|
|
315
|
+
if use_ema:
|
|
316
|
+
return self.ema_model(x)
|
|
317
|
+
|
|
318
|
+
if not self.training:
|
|
319
|
+
return self.asac(x)
|
|
320
|
+
|
|
321
|
+
# get EMA targets
|
|
322
|
+
with torch.no_grad():
|
|
323
|
+
self.ema_model.eval()
|
|
324
|
+
ema_outputs = self.ema_model.ema_model(x)
|
|
325
|
+
ema_targets = [sim.detach() for sim in ema_outputs.dot_sims]
|
|
326
|
+
self.ema_model.train()
|
|
327
|
+
|
|
328
|
+
return self.asac(x, attn_schema_targets = ema_targets)
|
|
329
|
+
|
|
330
|
+
def update(self):
|
|
331
|
+
self.ema_model.update()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ASAC.ASAC import ASAC, PatchEmbedding, EMA_ASAC
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from ASAC.ASAC import ASAC, PatchEmbedding
|
|
File without changes
|
|
File without changes
|
|
File without changes
|