ASAC-pytorch 0.0.9__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/ASAC/ASAC.py +85 -19
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/PKG-INFO +1 -1
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/pyproject.toml +1 -1
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/.gitignore +0 -0
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/ASAC/__init__.py +0 -0
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/LICENSE +0 -0
- {asac_pytorch-0.0.9 → asac_pytorch-0.0.10}/README.md +0 -0
|
@@ -6,10 +6,10 @@ from torch import nn, tensor
|
|
|
6
6
|
from torch.nn import Module, Linear, ModuleList
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
|
|
9
|
-
from einops import einsum, reduce
|
|
9
|
+
from einops import einsum, reduce, rearrange
|
|
10
10
|
from einops.layers.torch import Rearrange
|
|
11
11
|
|
|
12
|
-
from x_transformers import Decoder
|
|
12
|
+
from x_transformers import Decoder, AutoregressiveWrapper, TransformerWrapper
|
|
13
13
|
|
|
14
14
|
from x_mlps_pytorch import MLP
|
|
15
15
|
|
|
@@ -29,8 +29,8 @@ def default(v, d):
|
|
|
29
29
|
|
|
30
30
|
# return types
|
|
31
31
|
|
|
32
|
-
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', '
|
|
33
|
-
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', '
|
|
32
|
+
AttentionReturn = namedtuple('AttentionReturn', ['attended', 'indices', 'aux_loss', 'aux_loss_breakdown', 'attn_sim'])
|
|
33
|
+
ASACReturn = namedtuple('ASACReturn', ['logits', 'aux_loss', 'aux_loss_breakdown', 'attn_sims', 'attn_schema_indices', 'attn_schema_autoregressive_loss'])
|
|
34
34
|
|
|
35
35
|
# feedforward
|
|
36
36
|
|
|
@@ -65,7 +65,8 @@ class Attention(Module):
|
|
|
65
65
|
heads = 8,
|
|
66
66
|
k_rmsnorm = True,
|
|
67
67
|
attn_schema: Module | None = None,
|
|
68
|
-
attn_add_residual = True # they had to add a residual for stability
|
|
68
|
+
attn_add_residual = True, # they had to add a residual for stability
|
|
69
|
+
stochastic_sample_attn = False
|
|
69
70
|
):
|
|
70
71
|
super().__init__()
|
|
71
72
|
self.scale = dim_head ** -0.5
|
|
@@ -84,6 +85,8 @@ class Attention(Module):
|
|
|
84
85
|
self.attn_schema = attn_schema
|
|
85
86
|
self.attn_add_residual = attn_add_residual and attn_schema
|
|
86
87
|
|
|
88
|
+
self.stochastic_sample_attn = stochastic_sample_attn
|
|
89
|
+
|
|
87
90
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
88
91
|
|
|
89
92
|
def forward(
|
|
@@ -127,7 +130,10 @@ class Attention(Module):
|
|
|
127
130
|
|
|
128
131
|
# attend
|
|
129
132
|
|
|
130
|
-
|
|
133
|
+
if self.stochastic_sample_attn:
|
|
134
|
+
attn = F.gumbel_softmax(sim, tau = 1., hard = True, dim = -1)
|
|
135
|
+
else:
|
|
136
|
+
attn = sim.softmax(dim = -1)
|
|
131
137
|
|
|
132
138
|
# modulate
|
|
133
139
|
|
|
@@ -145,7 +151,13 @@ class Attention(Module):
|
|
|
145
151
|
|
|
146
152
|
attended = inverse_pack(attended)
|
|
147
153
|
|
|
148
|
-
return AttentionReturn(
|
|
154
|
+
return AttentionReturn(
|
|
155
|
+
attended,
|
|
156
|
+
indices,
|
|
157
|
+
aux_loss,
|
|
158
|
+
aux_loss_breakdown,
|
|
159
|
+
orig_sim
|
|
160
|
+
)
|
|
149
161
|
|
|
150
162
|
# attention autoencoder
|
|
151
163
|
|
|
@@ -245,10 +257,15 @@ class ASAC(Module):
|
|
|
245
257
|
vq_codebook_size = 256,
|
|
246
258
|
recon_loss_weight = 1.,
|
|
247
259
|
commit_loss_weight = 1.,
|
|
248
|
-
kl_div_loss = True
|
|
260
|
+
kl_div_loss = True,
|
|
261
|
+
stochastic_sample_attn = False,
|
|
262
|
+
awareness_model_depth = 2,
|
|
263
|
+
**awareness_model_kwargs
|
|
249
264
|
):
|
|
250
265
|
super().__init__()
|
|
251
266
|
|
|
267
|
+
assert depth >= 2, 'depth must be at least 2'
|
|
268
|
+
|
|
252
269
|
self.depth = depth
|
|
253
270
|
|
|
254
271
|
self.to_embedding = to_embedding
|
|
@@ -267,10 +284,34 @@ class ASAC(Module):
|
|
|
267
284
|
) if use_asac and exists(seq_len) else None
|
|
268
285
|
|
|
269
286
|
self.layers.append(ModuleList([
|
|
270
|
-
Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema),
|
|
287
|
+
Attention(dim, dim_head = dim_head, heads = heads, attn_schema = attn_schema, stochastic_sample_attn = stochastic_sample_attn),
|
|
271
288
|
FeedForward(dim)
|
|
272
289
|
]))
|
|
273
290
|
|
|
291
|
+
# autoregressive awareness model (attention schema theory)
|
|
292
|
+
|
|
293
|
+
self.awareness_transformer = None
|
|
294
|
+
self.awareness_model = None
|
|
295
|
+
|
|
296
|
+
if use_asac and exists(seq_len):
|
|
297
|
+
self.awareness_transformer = TransformerWrapper(
|
|
298
|
+
num_tokens = vq_codebook_size,
|
|
299
|
+
max_seq_len = depth,
|
|
300
|
+
attn_layers = Decoder(
|
|
301
|
+
dim = dim,
|
|
302
|
+
depth = awareness_model_depth,
|
|
303
|
+
heads = heads,
|
|
304
|
+
rotary_pos_emb = True,
|
|
305
|
+
**awareness_model_kwargs
|
|
306
|
+
)
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
self.awareness_model = AutoregressiveWrapper(self.awareness_transformer)
|
|
310
|
+
|
|
311
|
+
# zero buffer for auxiliary losses
|
|
312
|
+
|
|
313
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
314
|
+
|
|
274
315
|
self.to_logits = nn.Sequential(
|
|
275
316
|
nn.RMSNorm(dim),
|
|
276
317
|
Linear(dim, num_classes)
|
|
@@ -285,12 +326,15 @@ class ASAC(Module):
|
|
|
285
326
|
total_aux_loss = total_recon_loss = total_commit_loss = 0.
|
|
286
327
|
|
|
287
328
|
attn_schema_targets = default(attn_schema_targets, [None] * self.depth)
|
|
288
|
-
|
|
329
|
+
attn_sims = []
|
|
330
|
+
attn_schema_indices = []
|
|
289
331
|
|
|
290
332
|
for (attn, ff), target in zip(self.layers, attn_schema_targets):
|
|
291
|
-
attn_out, indices, aux_loss, (recon_loss, commit_loss),
|
|
333
|
+
attn_out, indices, aux_loss, (recon_loss, commit_loss), attn_sim = attn(x, attn_schema_target = target)
|
|
292
334
|
|
|
293
|
-
|
|
335
|
+
attn_sims.append(attn_sim)
|
|
336
|
+
if exists(indices):
|
|
337
|
+
attn_schema_indices.append(indices)
|
|
294
338
|
|
|
295
339
|
x = attn_out + x
|
|
296
340
|
x = ff(x) + x
|
|
@@ -303,14 +347,39 @@ class ASAC(Module):
|
|
|
303
347
|
|
|
304
348
|
logits = self.to_logits(x)
|
|
305
349
|
|
|
306
|
-
|
|
350
|
+
attn_schema_autoregressive_loss = self.zero
|
|
351
|
+
|
|
352
|
+
if attn_schema_indices:
|
|
353
|
+
attn_schema_indices = rearrange(attn_schema_indices, 'depth b ... -> b (depth ...)')
|
|
354
|
+
|
|
355
|
+
if exists(self.awareness_model):
|
|
356
|
+
attn_schema_autoregressive_loss = self.awareness_model(attn_schema_indices)
|
|
357
|
+
else:
|
|
358
|
+
attn_schema_indices = None
|
|
359
|
+
|
|
360
|
+
return ASACReturn(
|
|
361
|
+
logits,
|
|
362
|
+
total_aux_loss,
|
|
363
|
+
(total_recon_loss / self.depth, total_commit_loss / self.depth),
|
|
364
|
+
attn_sims,
|
|
365
|
+
attn_schema_indices,
|
|
366
|
+
attn_schema_autoregressive_loss
|
|
367
|
+
)
|
|
307
368
|
|
|
308
369
|
class EMA_ASAC(Module):
|
|
309
|
-
def __init__(
|
|
370
|
+
def __init__(
|
|
371
|
+
self,
|
|
372
|
+
asac_model,
|
|
373
|
+
ema_decay = 0.999,
|
|
374
|
+
**ema_kwargs
|
|
375
|
+
):
|
|
310
376
|
super().__init__()
|
|
311
377
|
self.asac = asac_model
|
|
312
378
|
self.ema_model = EMA(asac_model, beta = ema_decay, **ema_kwargs)
|
|
313
379
|
|
|
380
|
+
def update(self):
|
|
381
|
+
self.ema_model.update()
|
|
382
|
+
|
|
314
383
|
def forward(self, x, use_ema = False):
|
|
315
384
|
if use_ema:
|
|
316
385
|
return self.ema_model(x)
|
|
@@ -321,11 +390,8 @@ class EMA_ASAC(Module):
|
|
|
321
390
|
# get EMA targets
|
|
322
391
|
with torch.no_grad():
|
|
323
392
|
self.ema_model.eval()
|
|
324
|
-
ema_outputs = self.ema_model
|
|
325
|
-
ema_targets = [sim.detach() for sim in ema_outputs.
|
|
393
|
+
ema_outputs = self.ema_model(x)
|
|
394
|
+
ema_targets = [sim.detach() for sim in ema_outputs.attn_sims]
|
|
326
395
|
self.ema_model.train()
|
|
327
396
|
|
|
328
397
|
return self.asac(x, attn_schema_targets = ema_targets)
|
|
329
|
-
|
|
330
|
-
def update(self):
|
|
331
|
-
self.ema_model.update()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|