ASAC-pytorch 0.0.9__tar.gz → 0.0.10__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -6,10 +6,10 @@ from torch import nn, tensor
6
6
  from torch.nn import Module, Linear, ModuleList
7
7
  import torch.nn.functional as F
8
8
 
9
- from einops import einsum, reduce
9
+ from einops import einsum, reduce, rearrange
10
10
  from einops.layers.torch import Rearrange
11
11
 
12
- from x_transformers import Decoder
12
+ from x_transformers import Decoder, AutoregressiveWrapper, TransformerWrapper
13
13
 
14
14
  from x_mlps_pytorch import MLP
15
15
 
@@ -29,8 +29,8 @@ def default(v, d):
29
29
 
30
30
  # return types
31
31
 
32
- AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'dot_sim'])
33
- ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'dot_sims'])
32
+ AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'attn_sim'])
33
+ ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'attn_sims', 'attn_schema_indices', 'attn_schema_autoregressive_loss'])
34
34
 
35
35
  # feedforward
36
36
 
@@ -65,7 +65,8 @@ class Attention(Module):
65
65
  heads = 8,
66
66
  k_rmsnorm = True,
67
67
  attn_schema: Module | None = None,
68
- attn_add_residual = True # they had to add a residual for stability
68
+ attn_add_residual = True, # they had to add a residual for stability
69
+ stochastic_sample_attn = False
69
70
  ):
70
71
  super().__init__()
71
72
  self.scale = dim_head ** -0.5
@@ -84,6 +85,8 @@ class Attention(Module):
84
85
  self.attn_schema = attn_schema
85
86
  self.attn_add_residual = attn_add_residual and attn_schema
86
87
 
88
+ self.stochastic_sample_attn = stochastic_sample_attn
89
+
87
90
  self.register_buffer('zero', tensor(0.), persistent = False)
88
91
 
89
92
  def forward(
@@ -127,7 +130,10 @@ class Attention(Module):
127
130
 
128
131
  # attend
129
132
 
130
- attn = sim.softmax(dim = -1)
133
+ if self.stochastic_sample_attn:
134
+ attn = F.gumbel_softmax(sim, tau = 1., hard = True, dim = -1)
135
+ else:
136
+ attn = sim.softmax(dim = -1)
131
137
 
132
138
  # modulate
133
139
 
@@ -145,7 +151,13 @@ class Attention(Module):
145
151
 
146
152
  attended = inverse_pack(attended)
147
153
 
148
- return AttentionReturn(attended, indices, aux_loss, aux_loss_breakdown, orig_sim)
154
+ return AttentionReturn(
155
+ attended,
156
+ indices,
157
+ aux_loss,
158
+ aux_loss_breakdown,
159
+ orig_sim
160
+ )
149
161
 
150
162
  # attention autoencoder
151
163
 
@@ -245,10 +257,15 @@ class ASAC(Module):
245
257
  vq_codebook_size = 256,
246
258
  recon_loss_weight = 1.,
247
259
  commit_loss_weight = 1.,
248
- kl_div_loss = True
260
+ kl_div_loss = True,
261
+ stochastic_sample_attn = False,
262
+ awareness_model_depth = 2,
263
+ **awareness_model_kwargs
249
264
  ):
250
265
  super().__init__()
251
266
 
267
+ assert depth >= 2, 'depth must be at least 2'
268
+
252
269
  self.depth = depth
253
270
 
254
271
  self.to_embedding = to_embedding
@@ -267,10 +284,34 @@ class ASAC(Module):
267
284
  ) if use_asac and exists(seq_len) else None
268
285
 
269
286
  self.layers.append(ModuleList([
270
- Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema),
287
+ Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema, stochastic_sample_attn = stochastic_sample_attn),
271
288
  FeedForward(dim)
272
289
  ]))
273
290
 
291
+ # autoregressive awareness model (attention schema theory)
292
+
293
+ self.awareness_transformer = None
294
+ self.awareness_model = None
295
+
296
+ if use_asac and exists(seq_len):
297
+ self.awareness_transformer = TransformerWrapper(
298
+ num_tokens = vq_codebook_size,
299
+ max_seq_len = depth,
300
+ attn_layers = Decoder(
301
+ dim = dim,
302
+ depth = awareness_model_depth,
303
+ heads = heads,
304
+ rotary_pos_emb = True,
305
+ **awareness_model_kwargs
306
+ )
307
+ )
308
+
309
+ self.awareness_model = AutoregressiveWrapper(self.awareness_transformer)
310
+
311
+ # zero buffer for auxiliary losses
312
+
313
+ self.register_buffer('zero', tensor(0.), persistent = False)
314
+
274
315
  self.to_logits = nn.Sequential(
275
316
  nn.RMSNorm(dim),
276
317
  Linear(dim, num_classes)
@@ -285,12 +326,15 @@ class ASAC(Module):
285
326
  total_aux_loss = total_recon_loss = total_commit_loss = 0.
286
327
 
287
328
  attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
288
- dot_sims = []
329
+ attn_sims = []
330
+ attn_schema_indices = []
289
331
 
290
332
  for (attn, ff), target in zip(self.layers, attn_schema_targets):
291
- attn_out, indices, aux_loss, (recon_loss, commit_loss), dot_sim = attn(x, attn_schema_target = target)
333
+ attn_out, indices, aux_loss, (recon_loss, commit_loss), attn_sim = attn(x, attn_schema_target = target)
292
334
 
293
- dot_sims.append(dot_sim)
335
+ attn_sims.append(attn_sim)
336
+ if exists(indices):
337
+ attn_schema_indices.append(indices)
294
338
 
295
339
  x = attn_out + x
296
340
  x = ff(x) + x
@@ -303,14 +347,39 @@ class ASAC(Module):
303
347
 
304
348
  logits = self.to_logits(x)
305
349
 
306
- return ASACReturn(logits, total_aux_loss, (total_recon_loss / self.depth, total_commit_loss / self.depth), dot_sims)
350
+ attn_schema_autoregressive_loss = self.zero
351
+
352
+ if attn_schema_indices:
353
+ attn_schema_indices = rearrange(attn_schema_indices, 'depth b ... -> b (depth ...)')
354
+
355
+ if exists(self.awareness_model):
356
+ attn_schema_autoregressive_loss = self.awareness_model(attn_schema_indices)
357
+ else:
358
+ attn_schema_indices = None
359
+
360
+ return ASACReturn(
361
+ logits,
362
+ total_aux_loss,
363
+ (total_recon_loss / self.depth, total_commit_loss / self.depth),
364
+ attn_sims,
365
+ attn_schema_indices,
366
+ attn_schema_autoregressive_loss
367
+ )
307
368
 
308
369
  class EMA_ASAC(Module):
309
- def __init__(self, asac_model, ema_decay = 0.999, **ema_kwargs):
370
+ def __init__(
371
+ self,
372
+ asac_model,
373
+ ema_decay = 0.999,
374
+ **ema_kwargs
375
+ ):
310
376
  super().__init__()
311
377
  self.asac = asac_model
312
378
  self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)
313
379
 
380
+ def update(self):
381
+ self.ema_model.update()
382
+
314
383
  def forward(self, x, use_ema = False):
315
384
  if use_ema:
316
385
  return self.ema_model(x)
@@ -321,11 +390,8 @@ class EMA_ASAC(Module):
321
390
  # get EMA targets
322
391
  with torch.no_grad():
323
392
  self.ema_model.eval()
324
- ema_outputs = self.ema_model.ema_model(x)
325
- ema_targets = [sim.detach() for sim in ema_outputs.dot_sims]
393
+ ema_outputs = self.ema_model(x)
394
+ ema_targets = [sim.detach() for sim in ema_outputs.attn_sims]
326
395
  self.ema_model.train()
327
396
 
328
397
  return self.asac(x, attn_schema_targets = ema_targets)
329
-
330
- def update(self):
331
- self.ema_model.update()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.9
3
+ Version: 0.0.10
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.9"
3
+ version = "0.0.10"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes