ASAC-pytorch 0.0.7__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,8 +29,8 @@ def default(v, d):
29
29
 
30
30
  # return types
31
31
 
32
- AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown'])
33
- ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown'])
32
+ AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'dot_sim'])
33
+ ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'dot_sims'])
34
34
 
35
35
  # feedforward
36
36
 
@@ -90,7 +90,8 @@ class Attention(Module):
90
90
  self,
91
91
  tokens, # (b h w d)
92
92
  pre_softmax_attn_gates = None,
93
- post_softmax_attn_gates = None
93
+ post_softmax_attn_gates = None,
94
+ attn_schema_target = None
94
95
  ):
95
96
  tokens = self.norm(tokens)
96
97
 
@@ -114,7 +115,7 @@ class Attention(Module):
114
115
  indices = None
115
116
 
116
117
  if exists(self.attn_schema):
117
- sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
118
+ sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim, target_sim = attn_schema_target)
118
119
 
119
120
  if self.attn_add_residual:
120
121
  sim = (sim + orig_sim) * 0.5
@@ -144,7 +145,7 @@ class Attention(Module):
144
145
 
145
146
  attended = inverse_pack(attended)
146
147
 
147
- return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)
148
+ return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown, orig_sim)
148
149
 
149
150
  # attention autoencoder
150
151
 
@@ -186,7 +187,8 @@ class AttentionSchema(Module):
186
187
  def forward(
187
188
  self,
188
189
  attn_sim,
189
- return_loss = None
190
+ return_loss = None,
191
+ target_sim = None
190
192
  ):
191
193
  return_loss = default(return_loss, self.training)
192
194
 
@@ -204,18 +206,20 @@ class AttentionSchema(Module):
204
206
 
205
207
  recon_loss = self.zero
206
208
 
209
+ target = default(target_sim, attn_sim)
210
+
207
211
  if return_loss:
208
212
  if self.detach_target:
209
- attn_sim = attn_sim.detach()
213
+ target = target.detach()
210
214
 
211
215
  if self.kl_div_loss:
212
216
  recon_loss = F.kl_div(
213
- attn_sim.log_softmax(dim = -1),
217
+ target.log_softmax(dim = -1),
214
218
  recon.softmax(dim = -1),
215
219
  reduction = 'none'
216
220
  ).sum(dim = -1).mean()
217
221
  else:
218
- recon_loss = F.mse_loss(recon, attn_sim)
222
+ recon_loss = F.mse_loss(recon, target)
219
223
 
220
224
  # total
221
225
 
@@ -272,7 +276,7 @@ class ASAC(Module):
272
276
  Linear(dim, num_classes)
273
277
  )
274
278
 
275
- def forward(self, x):
279
+ def forward(self, x, attn_schema_targets = None):
276
280
  x = self.to_embedding(x)
277
281
 
278
282
  if exists(self.pos_embedding):
@@ -280,8 +284,13 @@ class ASAC(Module):
280
284
 
281
285
  total_aux_loss = total_recon_loss = total_commit_loss = 0.
282
286
 
283
- for attn, ff in self.layers:
284
- attn_out, indices, aux_loss, (recon_loss, commit_loss) = attn(x)
287
+ attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
288
+ dot_sims = []
289
+
290
+ for (attn, ff), target in zip(self.layers, attn_schema_targets):
291
+ attn_out, indices, aux_loss, (recon_loss, commit_loss), dot_sim = attn(x, attn_schema_target = target)
292
+
293
+ dot_sims.append(dot_sim)
285
294
 
286
295
  x = attn_out + x
287
296
  x = ff(x) + x
@@ -294,4 +303,29 @@ class ASAC(Module):
294
303
 
295
304
  logits = self.to_logits(x)
296
305
 
297
- return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth))
306
+ return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth), dot_sims)
307
+
308
class EMA_ASAC(Module):
    """Wrap an ASAC model together with an EMA copy used as a schema target.

    During training, the EMA copy is run first (without gradients) and the
    `dot_sims` it returns are passed to the wrapped model as
    `attn_schema_targets`. Outside of training the wrapped model is called
    directly, and `use_ema = True` routes the call to the EMA wrapper instead.

    NOTE(review): `EMA` is imported elsewhere in the file; it is presumably
    `ema_pytorch.EMA`, whose `.ema_model` attribute holds the averaged copy
    -- confirm against the import.
    """

    def __init__(self, asac_model, ema_decay = 0.999, **ema_kwargs):
        """Store `asac_model` and build its EMA wrapper.

        Args:
            asac_model: the model to train; also handed to `EMA`.
            ema_decay: forwarded to `EMA` as `beta`.
            **ema_kwargs: any further keyword arguments forwarded to `EMA`.
        """
        super().__init__()
        self.asac = asac_model
        self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)

    def forward(self, x, use_ema = False):
        """Run the model, feeding EMA similarities as targets when training.

        Args:
            x: input passed unchanged to the underlying model(s).
            use_ema: when true, return the EMA wrapper's output for `x`.

        Returns:
            Whatever the selected underlying model returns for `x`.
        """
        if use_ema:
            return self.ema_model(x)

        if not self.training:
            return self.asac(x)

        # get EMA targets

        self.ema_model.eval()

        try:
            with torch.no_grad():
                ema_outputs = self.ema_model.ema_model(x)
                ema_targets = [sim.detach() for sim in ema_outputs.dot_sims]
        finally:
            # always put the EMA wrapper back in train mode, even if the
            # forward pass above raised, so it is never left stuck in eval
            self.ema_model.train()

        return self.asac(x, attn_schema_targets = ema_targets)

    def update(self):
        """Delegate to the EMA wrapper's `update()`."""
        self.ema_model.update()
@@ -0,0 +1 @@
1
+ from ASAC.ASAC import ASAC, PatchEmbedding, EMA_ASAC
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.7"
3
+ version = "0.0.9"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1 +0,0 @@
1
- from ASAC.ASAC import ASAC, PatchEmbedding
File without changes
File without changes
File without changes