ASAC-pytorch 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asac_pytorch-0.0.2/ASAC/ASAC.py +165 -0
- asac_pytorch-0.0.2/ASAC/__init__.py +1 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.2}/PKG-INFO +5 -1
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.2}/pyproject.toml +5 -1
- asac_pytorch-0.0.1/ASAC/ASAC.py +0 -0
- asac_pytorch-0.0.1/ASAC/__init__.py +0 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.2}/.gitignore +0 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.2}/LICENSE +0 -0
- {asac_pytorch-0.0.1 → asac_pytorch-0.0.2}/README.md +0 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, tensor
|
|
5
|
+
from torch.nn import Module, Linear
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
from einops import einsum
|
|
9
|
+
from einops.layers.torch import Rearrange
|
|
10
|
+
|
|
11
|
+
from x_transformers import Decoder
|
|
12
|
+
|
|
13
|
+
from x_mlps_pytorch import MLP
|
|
14
|
+
|
|
15
|
+
from vector_quantize_pytorch import VectorQuantize
|
|
16
|
+
|
|
17
|
+
from ema_pytorch import EMA
|
|
18
|
+
|
|
19
|
+
from torch_einops_utils import pack_with_inverse
|
|
20
|
+
|
|
21
|
+
# helpers
|
|
22
|
+
|
|
23
|
+
def exists(v):
    """Return True for every value except None (falsy values such as 0 or '' still count)."""
    is_missing = v is None
    return not is_missing
|
|
25
|
+
|
|
26
|
+
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    # only None triggers the fallback - falsy values like 0 or '' are kept
    if v is None:
        return d

    return v
|
|
28
|
+
|
|
29
|
+
# attention
|
|
30
|
+
|
|
31
|
+
class Attention(Module):
    """Multi-head self attention with an optional attention schema applied to the pre-softmax similarities.

    When `attn_schema` is given, the raw similarity matrix is passed through it and
    the schema's output is used as the attention logits (optionally with the raw
    similarities added back as a residual).
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        attn_schema: Module | None = None,
        attn_add_residual = True # they had to add a residual for stability
    ):
        """
        Args:
            dim: feature dimension of the incoming tokens (last axis).
            dim_head: dimension of each attention head.
            heads: number of attention heads.
            attn_schema: optional module called on the similarities; must return
                a `(sim, indices, aux_loss)` triple, as consumed in `forward`.
            attn_add_residual: whether to add the raw similarities back onto the
                schema output. Has no effect when `attn_schema` is None.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.to_qkv = Linear(dim, dim_inner * 3, bias = False)
        self.combine_heads = Linear(dim_inner, dim, bias = False)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.attn_schema = attn_schema

        # keep this a plain bool - `flag and attn_schema` would evaluate to the module itself,
        # which nn.Module would register a second time under this name (duplicating its weights in the state dict)

        self.attn_add_residual = attn_add_residual and exists(attn_schema)

        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        tokens, # (b h w d)
    ):
        """Attend over the tokens.

        Returns:
            A tuple `(attended, indices, aux_loss)`. `attended` is the attention
            output after `inverse_pack` (presumably restoring the input layout - confirm
            against torch_einops_utils). `indices` is whatever the schema returned as
            its second output, or None without a schema. `aux_loss` is the schema's
            third output, or the zero buffer without a schema.
        """

        # pack everything between batch and feature into one sequence axis, the inverse is applied at the end

        tokens, inverse_pack = pack_with_inverse(tokens, 'b * d')

        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
        q, k, v = (self.split_heads(t) for t in (q, k, v))

        q = q * self.scale

        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

        orig_sim = sim

        # the proposal

        # defaults for when there is no schema, so the return below never hits an unbound name

        indices = None
        aux_loss = self.zero

        if exists(self.attn_schema):
            sim, indices, aux_loss = self.attn_schema(orig_sim)

        if self.attn_add_residual:
            sim = sim + orig_sim

        # attend

        attn = sim.softmax(dim = -1)

        # aggregate and combine out

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

        out = self.merge_heads(out)
        attended = self.combine_heads(out)

        # bring back the packed dimensions

        attended = inverse_pack(attended)

        return attended, indices, aux_loss
|
|
96
|
+
|
|
97
|
+
# attention autoencoder
|
|
98
|
+
|
|
99
|
+
class AttentionSchema(Module):
    """Vector-quantized autoencoder over attention similarities.

    The similarities are encoded by an MLP, quantized, and decoded back by a
    second MLP. When a loss is requested, the reconstruction is compared to the
    input similarities and the quantizer's commitment loss is added.
    """

    def __init__(
        self,
        dim,
        dim_bottleneck,
        kl_div_loss = True,
        detach_target = True,
        **vq_kwargs
    ):
        """
        Args:
            dim: input size of the encoder and output size of the decoder.
                Presumably must equal the flattened non-batch size of the
                similarities given to `forward` (they are packed as 'b *') - confirm.
            dim_bottleneck: size of the quantized bottleneck.
            kl_div_loss: use a KL divergence reconstruction loss when True,
                mean squared error otherwise.
            detach_target: detach the input similarities before computing the
                reconstruction loss.
            **vq_kwargs: forwarded to `VectorQuantize`.
        """
        super().__init__()
        self.encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())

        self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)

        self.decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())

        self.kl_div_loss = kl_div_loss
        self.detach_target = detach_target

    def forward(
        self,
        attn_sim,
        return_loss = None
    ):
        """Reconstruct the similarities through the quantized bottleneck.

        Args:
            attn_sim: attention similarities (pre-softmax logits at the call site in `Attention`).
            return_loss: whether to compute the loss; defaults to `self.training`.

        Returns:
            A tuple `(recon, indices, total_loss)`. `total_loss` is the
            reconstruction loss plus commitment loss when a loss is requested,
            and zeros shaped like the commitment loss otherwise.
        """
        return_loss = default(return_loss, self.training)

        attn_features, inverse_pack = pack_with_inverse(attn_sim, 'b *')

        encoded = self.encoder(attn_features)

        quantized, indices, commit_loss = self.vq(encoded)

        decoded = self.decoder(quantized)

        recon = inverse_pack(decoded)

        # no loss requested - still return a well defined zero so callers can always unpack three values

        if not return_loss:
            return recon, indices, torch.zeros_like(commit_loss)

        # loss, mse as in paper or reverse kl

        target = attn_sim.detach() if self.detach_target else attn_sim

        if self.kl_div_loss:
            # the decoder output is unnormalized and can be negative, so it cannot be passed to kl_div as probabilities
            # normalize both sides in log space and mark the target as log-probabilities, argument order unchanged

            recon_loss = F.kl_div(
                target.log_softmax(dim = -1),
                recon.log_softmax(dim = -1),
                log_target = True,
                reduction = 'batchmean'
            )
        else:
            recon_loss = F.mse_loss(recon, target)

        # total

        total_loss = recon_loss + commit_loss

        return recon, indices, total_loss
|
|
157
|
+
|
|
158
|
+
# class
|
|
159
|
+
|
|
160
|
+
class ASAC(Module):
    """Top-level module; at this version it is a placeholder that passes its input through unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # nothing is wired in yet - hand the input straight back
        passthrough = x
        return passthrough
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ASAC.ASAC import ASAC
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.2
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
@@ -36,8 +36,12 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
36
36
|
Requires-Python: >=3.10
|
|
37
37
|
Requires-Dist: einops>=0.8.1
|
|
38
38
|
Requires-Dist: einx>=0.3.0
|
|
39
|
+
Requires-Dist: ema-pytorch
|
|
39
40
|
Requires-Dist: torch-einops-utils>=0.1.2
|
|
40
41
|
Requires-Dist: torch>=2.5
|
|
42
|
+
Requires-Dist: vector-quantize-pytorch
|
|
43
|
+
Requires-Dist: x-mlps-pytorch
|
|
44
|
+
Requires-Dist: x-transformers
|
|
41
45
|
Provides-Extra: examples
|
|
42
46
|
Provides-Extra: test
|
|
43
47
|
Requires-Dist: pytest; extra == 'test'
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "ASAC-pytorch"
|
|
3
|
-
version = "0.0.1"
|
|
3
|
+
version = "0.0.2"
|
|
4
4
|
description = "Implementation of Attention Schema-based Attention Control (ASAC)"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -26,8 +26,12 @@ classifiers=[
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"einx>=0.3.0",
|
|
28
28
|
"einops>=0.8.1",
|
|
29
|
+
"ema-pytorch",
|
|
29
30
|
"torch>=2.5",
|
|
30
31
|
"torch-einops-utils>=0.1.2",
|
|
32
|
+
"vector-quantize-pytorch",
|
|
33
|
+
"x-transformers",
|
|
34
|
+
"x-mlps-pytorch"
|
|
31
35
|
]
|
|
32
36
|
|
|
33
37
|
[project.urls]
|
asac_pytorch-0.0.1/ASAC/ASAC.py
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|