ASAC-pytorch 0.0.4__tar.gz → 0.0.5__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -117,7 +117,7 @@ class Attention(Module):
117
117
  sim, indices, aux_loss, aux_loss_breakdown = self.attn_schema(orig_sim)
118
118
 
119
119
  if self.attn_add_residual:
120
- sim = sim + orig_sim
120
+ sim = (sim + orig_sim) * 0.5
121
121
 
122
122
  # modulate
123
123
 
@@ -153,7 +153,7 @@ class AttentionSchema(Module):
153
153
  self,
154
154
  dim,
155
155
  dim_bottleneck,
156
- kl_div_loss = True,
156
+ kl_div_loss = False,
157
157
  detach_target = True,
158
158
  encoder: Module | None = None,
159
159
  decoder: Module | None = None,
@@ -181,6 +181,8 @@ class AttentionSchema(Module):
181
181
  self.recon_loss_weight = recon_loss_weight
182
182
  self.commit_loss_weight = commit_loss_weight
183
183
 
184
+ self.register_buffer('zero', tensor(0.), persistent = False)
185
+
184
186
  def forward(
185
187
  self,
186
188
  attn_sim,
@@ -200,6 +202,8 @@ class AttentionSchema(Module):
200
202
 
201
203
  # loss, mse as in paper or reverse kl
202
204
 
205
+ recon_loss = self.zero
206
+
203
207
  if return_loss:
204
208
  if self.detach_target:
205
209
  attn_sim = attn_sim.detach()
@@ -208,8 +212,8 @@ class AttentionSchema(Module):
208
212
  recon_loss = F.kl_div(
209
213
  attn_sim.log_softmax(dim = -1),
210
214
  recon.softmax(dim = -1),
211
- reduction = 'batchmean'
212
- )
215
+ reduction = 'none'
216
+ ).sum(dim = -1).mean()
213
217
  else:
214
218
  recon_loss = F.mse_loss(recon, attn_sim)
215
219
 
@@ -236,7 +240,8 @@ class ASAC(Module):
236
240
  dim_bottleneck = 256,
237
241
  vq_codebook_size = 256,
238
242
  recon_loss_weight = 1.,
239
- commit_loss_weight = 1.
243
+ commit_loss_weight = 1.,
244
+ kl_div_loss = False
240
245
  ):
241
246
  super().__init__()
242
247
 
@@ -253,7 +258,8 @@ class ASAC(Module):
253
258
  dim_bottleneck = dim_bottleneck,
254
259
  codebook_size = vq_codebook_size,
255
260
  recon_loss_weight = recon_loss_weight,
256
- commit_loss_weight = commit_loss_weight
261
+ commit_loss_weight = commit_loss_weight,
262
+ kl_div_loss = kl_div_loss
257
263
  ) if use_asac and exists(seq_len) else None
258
264
 
259
265
  self.layers.append(ModuleList([
@@ -272,9 +278,7 @@ class ASAC(Module):
272
278
  if exists(self.pos_embedding):
273
279
  x = x + self.pos_embedding
274
280
 
275
- total_aux_loss = 0.
276
- total_recon_loss = 0.
277
- total_commit_loss = 0.
281
+ total_aux_loss = total_recon_loss = total_commit_loss = 0.
278
282
 
279
283
  for attn, ff in self.layers:
280
284
  attn_out, indices, aux_loss, (recon_loss, commit_loss) = attn(x)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes