ASAC-pytorch 0.0.7__tar.gz → 0.0.10__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
@@ -6,10 +6,10 @@ from torch import nn, tensor
6
6
  from torch.nn import Module, Linear, ModuleList
7
7
  import torch.nn.functional as F
8
8
 
9
- from einops import einsum, reduce
9
+ from einops import einsum, reduce, rearrange
10
10
  from einops.layers.torch import Rearrange
11
11
 
12
- from x_transformers import Decoder
12
+ from x_transformers import Decoder, AutoregressiveWrapper, TransformerWrapper
13
13
 
14
14
  from x_mlps_pytorch import MLP
15
15
 
@@ -29,8 +29,8 @@ def default(v, d):
29
29
 
30
30
  # return types
31
31
 
32
- AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown'])
33
- ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown'])
32
+ AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'attn_sim'])
33
+ ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'attn_sims', 'attn_schema_indices', 'attn_schema_autoregressive_loss'])
34
34
 
35
35
  # feedforward
36
36
 
@@ -65,7 +65,8 @@ class Attention(Module):
65
65
  heads = 8,
66
66
  k_rmsnorm = True,
67
67
  attn_schema: Module | None = None,
68
- attn_add_residual = True # they had to add a residual for stability
68
+ attn_add_residual = True, # they had to add a residual for stability
69
+ stochastic_sample_attn = False
69
70
  ):
70
71
  super().__init__()
71
72
  self.scale = dim_head ** -0.5
@@ -84,13 +85,16 @@ class Attention(Module):
84
85
  self.attn_schema = attn_schema
85
86
  self.attn_add_residual = attn_add_residual and attn_schema
86
87
 
88
+ self.stochastic_sample_attn = stochastic_sample_attn
89
+
87
90
  self.register_buffer('zero', tensor(0.), persistent = False)
88
91
 
89
92
  def forward(
90
93
  self,
91
94
  tokens, # (b h w d)
92
95
  pre_softmax_attn_gates = None,
93
- post_softmax_attn_gates = None
96
+ post_softmax_attn_gates = None,
97
+ attn_schema_target = None
94
98
  ):
95
99
  tokens = self.norm(tokens)
96
100
 
@@ -114,7 +118,7 @@ class Attention(Module):
114
118
  indices = None
115
119
 
116
120
  if exists(self.attn_schema):
117
- sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
121
+ sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim, target_sim = attn_schema_target)
118
122
 
119
123
  if self.attn_add_residual:
120
124
  sim = (sim + orig_sim) * 0.5
@@ -126,7 +130,10 @@ class Attention(Module):
126
130
 
127
131
  # attend
128
132
 
129
- attn = sim.softmax(dim = -1)
133
+ if self.stochastic_sample_attn:
134
+ attn = F.gumbel_softmax(sim, tau = 1., hard = True, dim = -1)
135
+ else:
136
+ attn = sim.softmax(dim = -1)
130
137
 
131
138
  # modulate
132
139
 
@@ -144,7 +151,13 @@ class Attention(Module):
144
151
 
145
152
  attended = inverse_pack(attended)
146
153
 
147
- return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown)
154
+ return AttentionReturn(
155
+ attended,
156
+ indices,
157
+ aux_loss,
158
+ aux_loss_breakdown,
159
+ orig_sim
160
+ )
148
161
 
149
162
  # attention autoencoder
150
163
 
@@ -186,7 +199,8 @@ class AttentionSchema(Module):
186
199
  def forward(
187
200
  self,
188
201
  attn_sim,
189
- return_loss = None
202
+ return_loss = None,
203
+ target_sim = None
190
204
  ):
191
205
  return_loss = default(return_loss, self.training)
192
206
 
@@ -204,18 +218,20 @@ class AttentionSchema(Module):
204
218
 
205
219
  recon_loss = self.zero
206
220
 
221
+ target = default(target_sim, attn_sim)
222
+
207
223
  if return_loss:
208
224
  if self.detach_target:
209
- attn_sim = attn_sim.detach()
225
+ target = target.detach()
210
226
 
211
227
  if self.kl_div_loss:
212
228
  recon_loss = F.kl_div(
213
- attn_sim.log_softmax(dim = -1),
229
+ target.log_softmax(dim = -1),
214
230
  recon.softmax(dim = -1),
215
231
  reduction = 'none'
216
232
  ).sum(dim = -1).mean()
217
233
  else:
218
- recon_loss = F.mse_loss(recon, attn_sim)
234
+ recon_loss = F.mse_loss(recon, target)
219
235
 
220
236
  # total
221
237
 
@@ -241,10 +257,15 @@ class ASAC(Module):
241
257
  vq_codebook_size = 256,
242
258
  recon_loss_weight = 1.,
243
259
  commit_loss_weight = 1.,
244
- kl_div_loss = True
260
+ kl_div_loss = True,
261
+ stochastic_sample_attn = False,
262
+ awareness_model_depth = 2,
263
+ **awareness_model_kwargs
245
264
  ):
246
265
  super().__init__()
247
266
 
267
+ assert depth >= 2, 'depth must be at least 2'
268
+
248
269
  self.depth = depth
249
270
 
250
271
  self.to_embedding = to_embedding
@@ -263,16 +284,40 @@ class ASAC(Module):
263
284
  ) if use_asac and exists(seq_len) else None
264
285
 
265
286
  self.layers.append(ModuleList([
266
- Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema),
287
+ Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema, stochastic_sample_attn = stochastic_sample_attn),
267
288
  FeedForward(dim)
268
289
  ]))
269
290
 
291
+ # autoregressive awareness model (attention schema theory)
292
+
293
+ self.awareness_transformer = None
294
+ self.awareness_model = None
295
+
296
+ if use_asac and exists(seq_len):
297
+ self.awareness_transformer = TransformerWrapper(
298
+ num_tokens = vq_codebook_size,
299
+ max_seq_len = depth,
300
+ attn_layers = Decoder(
301
+ dim = dim,
302
+ depth = awareness_model_depth,
303
+ heads = heads,
304
+ rotary_pos_emb = True,
305
+ **awareness_model_kwargs
306
+ )
307
+ )
308
+
309
+ self.awareness_model = AutoregressiveWrapper(self.awareness_transformer)
310
+
311
+ # zero buffer for auxiliary losses
312
+
313
+ self.register_buffer('zero', tensor(0.), persistent = False)
314
+
270
315
  self.to_logits = nn.Sequential(
271
316
  nn.RMSNorm(dim),
272
317
  Linear(dim, num_classes)
273
318
  )
274
319
 
275
- def forward(self, x):
320
+ def forward(self, x, attn_schema_targets = None):
276
321
  x = self.to_embedding(x)
277
322
 
278
323
  if exists(self.pos_embedding):
@@ -280,8 +325,16 @@ class ASAC(Module):
280
325
 
281
326
  total_aux_loss = total_recon_loss = total_commit_loss = 0.
282
327
 
283
- for attn, ff in self.layers:
284
- attn_out, indices, aux_loss, (recon_loss, commit_loss) = attn(x)
328
+ attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
329
+ attn_sims = []
330
+ attn_schema_indices = []
331
+
332
+ for (attn, ff), target in zip(self.layers, attn_schema_targets):
333
+ attn_out, indices, aux_loss, (recon_loss, commit_loss), attn_sim = attn(x, attn_schema_target = target)
334
+
335
+ attn_sims.append(attn_sim)
336
+ if exists(indices):
337
+ attn_schema_indices.append(indices)
285
338
 
286
339
  x = attn_out + x
287
340
  x = ff(x) + x
@@ -294,4 +347,51 @@ class ASAC(Module):
294
347
 
295
348
  logits = self.to_logits(x)
296
349
 
297
- return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth))
350
+ attn_schema_autoregressive_loss = self.zero
351
+
352
+ if attn_schema_indices:
353
+ attn_schema_indices = rearrange(attn_schema_indices, 'depth b ... -> b (depth ...)')
354
+
355
+ if exists(self.awareness_model):
356
+ attn_schema_autoregressive_loss = self.awareness_model(attn_schema_indices)
357
+ else:
358
+ attn_schema_indices = None
359
+
360
+ return ASACReturn(
361
+ logits,
362
+ total_aux_loss,
363
+ (total_recon_loss / self.depth, total_commit_loss / self.depth),
364
+ attn_sims,
365
+ attn_schema_indices,
366
+ attn_schema_autoregressive_loss
367
+ )
368
+
369
+ class EMA_ASAC(Module):
370
+ def __init__(
371
+ self,
372
+ asac_model,
373
+ ema_decay = 0.999,
374
+ **ema_kwargs
375
+ ):
376
+ super().__init__()
377
+ self.asac = asac_model
378
+ self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)
379
+
380
+ def update(self):
381
+ self.ema_model.update()
382
+
383
+ def forward(self, x, use_ema = False):
384
+ if use_ema:
385
+ return self.ema_model(x)
386
+
387
+ if not self.training:
388
+ return self.asac(x)
389
+
390
+ # get EMA targets
391
+ with torch.no_grad():
392
+ self.ema_model.eval()
393
+ ema_outputs = self.ema_model(x)
394
+ ema_targets = [sim.detach() for sim in ema_outputs.attn_sims]
395
+ self.ema_model.train()
396
+
397
+ return self.asac(x, attn_schema_targets = ema_targets)
@@ -0,0 +1 @@
1
+ from ASAC.ASAC import ASAC, PatchEmbedding, EMA_ASAC
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.7
3
+ Version: 0.0.10
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.7"
3
+ version = "0.0.10"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1 +0,0 @@
1
- from ASAC.ASAC import ASAC, PatchEmbedding
File without changes
File without changes
File without changes