ASAC-pytorch 0.0.4__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/ASAC/ASAC.py +14 -10
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/PKG-INFO +2 -2
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/pyproject.toml +2 -2
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/.gitignore +0 -0
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/ASAC/__init__.py +0 -0
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/LICENSE +0 -0
- {asac_pytorch-0.0.4 → asac_pytorch-0.0.7}/README.md +0 -0
|
@@ -117,7 +117,7 @@ class Attention(Module):
|
|
|
117
117
|
sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
|
|
118
118
|
|
|
119
119
|
if self.attn_add_residual:
|
|
120
|
-
sim = sim + orig_sim
|
|
120
|
+
sim = (sim + orig_sim) * 0.5
|
|
121
121
|
|
|
122
122
|
# modulate
|
|
123
123
|
|
|
@@ -164,14 +164,14 @@ class AttentionSchema(Module):
|
|
|
164
164
|
super().__init__()
|
|
165
165
|
|
|
166
166
|
if not exists(encoder):
|
|
167
|
-
encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())
|
|
167
|
+
encoder = MLP(dim, dim_bottleneck, dim_bottleneck, activation = nn.LeakyReLU())
|
|
168
168
|
|
|
169
169
|
self.encoder = encoder
|
|
170
170
|
|
|
171
171
|
self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)
|
|
172
172
|
|
|
173
173
|
if not exists(decoder):
|
|
174
|
-
decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())
|
|
174
|
+
decoder = MLP(dim_bottleneck, dim_bottleneck, dim, activation = nn.LeakyReLU())
|
|
175
175
|
|
|
176
176
|
self.decoder = decoder
|
|
177
177
|
|
|
@@ -181,6 +181,8 @@ class AttentionSchema(Module):
|
|
|
181
181
|
self.recon_loss_weight = recon_loss_weight
|
|
182
182
|
self.commit_loss_weight = commit_loss_weight
|
|
183
183
|
|
|
184
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
185
|
+
|
|
184
186
|
def forward(
|
|
185
187
|
self,
|
|
186
188
|
attn_sim,
|
|
@@ -200,6 +202,8 @@ class AttentionSchema(Module):
|
|
|
200
202
|
|
|
201
203
|
# loss, mse as in paper or reverse kl
|
|
202
204
|
|
|
205
|
+
recon_loss = self.zero
|
|
206
|
+
|
|
203
207
|
if return_loss:
|
|
204
208
|
if self.detach_target:
|
|
205
209
|
attn_sim = attn_sim.detach()
|
|
@@ -208,8 +212,8 @@ class AttentionSchema(Module):
|
|
|
208
212
|
recon_loss = F.kl_div(
|
|
209
213
|
attn_sim.log_softmax(dim = -1),
|
|
210
214
|
recon.softmax(dim = -1),
|
|
211
|
-
reduction = '
|
|
212
|
-
)
|
|
215
|
+
reduction = 'none'
|
|
216
|
+
).sum(dim = -1).mean()
|
|
213
217
|
else:
|
|
214
218
|
recon_loss = F.mse_loss(recon, attn_sim)
|
|
215
219
|
|
|
@@ -236,7 +240,8 @@ class ASAC(Module):
|
|
|
236
240
|
dim_bottleneck = 256,
|
|
237
241
|
vq_codebook_size = 256,
|
|
238
242
|
recon_loss_weight = 1.,
|
|
239
|
-
commit_loss_weight = 1
|
|
243
|
+
commit_loss_weight = 1.,
|
|
244
|
+
kl_div_loss = True
|
|
240
245
|
):
|
|
241
246
|
super().__init__()
|
|
242
247
|
|
|
@@ -253,7 +258,8 @@ class ASAC(Module):
|
|
|
253
258
|
dim_bottleneck = dim_bottleneck,
|
|
254
259
|
codebook_size = vq_codebook_size,
|
|
255
260
|
recon_loss_weight = recon_loss_weight,
|
|
256
|
-
commit_loss_weight = commit_loss_weight
|
|
261
|
+
commit_loss_weight = commit_loss_weight,
|
|
262
|
+
kl_div_loss = kl_div_loss
|
|
257
263
|
) if use_asac and exists(seq_len) else None
|
|
258
264
|
|
|
259
265
|
self.layers.append(ModuleList([
|
|
@@ -272,9 +278,7 @@ class ASAC(Module):
|
|
|
272
278
|
if exists(self.pos_embedding):
|
|
273
279
|
x = x + self.pos_embedding
|
|
274
280
|
|
|
275
|
-
total_aux_loss = 0.
|
|
276
|
-
total_recon_loss = 0.
|
|
277
|
-
total_commit_loss = 0.
|
|
281
|
+
total_aux_loss = total_recon_loss = total_commit_loss = 0.
|
|
278
282
|
|
|
279
283
|
for attn, ff in self.layers:
|
|
280
284
|
attn_out, indices, aux_loss, (recon_loss, commit_loss) = attn(x)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ASAC-pytorch
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Implementation of Attention Schema-based Attention Control (ASAC)
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/ASAC/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
|
|
@@ -38,7 +38,7 @@ Requires-Dist: einops>=0.8.1
|
|
|
38
38
|
Requires-Dist: einx>=0.3.0
|
|
39
39
|
Requires-Dist: ema-pytorch
|
|
40
40
|
Requires-Dist: torch-einops-utils>=0.1.2
|
|
41
|
-
Requires-Dist: torch>=2.
|
|
41
|
+
Requires-Dist: torch>=2.4
|
|
42
42
|
Requires-Dist: vector-quantize-pytorch
|
|
43
43
|
Requires-Dist: x-mlps-pytorch
|
|
44
44
|
Requires-Dist: x-transformers
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "ASAC-pytorch"
|
|
3
|
-
version = "0.0.4"
|
|
3
|
+
version = "0.0.7"
|
|
4
4
|
description = "Implementation of Attention Schema-based Attention Control (ASAC)"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -27,7 +27,7 @@ dependencies = [
|
|
|
27
27
|
"einx>=0.3.0",
|
|
28
28
|
"einops>=0.8.1",
|
|
29
29
|
"ema-pytorch",
|
|
30
|
-
"torch>=2.
|
|
30
|
+
"torch>=2.4",
|
|
31
31
|
"torch-einops-utils>=0.1.2",
|
|
32
32
|
"vector-quantize-pytorch",
|
|
33
33
|
"x-transformers",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|