x-transformers 1.27.21__py3-none-any.whl → 1.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py CHANGED
@@ -889,7 +889,15 @@ class Attention(nn.Module):
             else:
                 input_mask = torch.cat((mem_mask, input_mask), dim = -1)
 
-        if self.num_mem_kv > 0:
+        # i, j determined for relative positional bias, excluding memory key / values
+
+        i, j = map(lambda t: t.shape[-2], (q, k))
+
+        # maybe append memory key / values
+
+        has_mem_kv = self.num_mem_kv > 0
+
+        if has_mem_kv:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
 
             if self.qk_norm:
@@ -902,8 +910,6 @@ class Attention(nn.Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
 
-        i, j = map(lambda t: t.shape[-2], (q, k))
-
         # determine masking
 
         mask_value = max_neg_value(q)
@@ -938,6 +944,11 @@ class Attention(nn.Module):
         if exists(rel_pos):
            attn_bias = rel_pos(i, j)
 
+            # append with no bias for memory key / values
+
+            if has_mem_kv:
+                attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+
         # attention is all we need
 
         out, intermediates = self.attend(
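
Note on the x_transformers.py change above: i and j are now measured before the memory key / values are appended, so the relative positional bias is computed only for the real query / key positions, and the bias is then left-padded with zeros along the key dimension to cover the memory slots. Below is a minimal sketch of that padding step, using toy shapes and a simplified stand-in for the library's pad_at_dim helper (not the package's own code):

    import torch
    import torch.nn.functional as F

    def pad_at_dim(t, pad, dim = -1, value = 0.):
        # simplified stand-in: pad tensor t at dim by (left, right) amounts with a constant value
        dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = (0, 0) * dims_from_right
        return F.pad(t, (*zeros, *pad), value = value)

    num_mem_kv = 4
    i, j = 8, 8                       # query / key lengths before memory key / values are appended

    attn_bias = torch.randn(1, i, j)  # relative positional bias for the real tokens only

    # memory key / values receive no positional bias, so left-pad the key dimension with zeros
    attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)

    assert attn_bias.shape[-1] == j + num_mem_kv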
x_transformers/xval.py CHANGED
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
         inp, target = x[:, :-1], x[:, 1:]
         x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
 
+        # ignore index
+
+        target_mask = target != self.ignore_index
+
+        # key padding mask
+
         mask = kwargs.get('mask', None)
-        if exists(mask) and mask.shape[1] == x.shape[1]:
-            mask = mask[:, :-1]
-            kwargs['mask'] = mask
+        if exists(mask):
+            target_mask &= mask
+
+            if mask.shape[1] == x.shape[1]:
+                mask = mask[:, :-1]
+                kwargs['mask'] = mask
 
         logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
 
@@ -276,21 +285,18 @@ class XValAutoregressiveWrapper(nn.Module):
         target_is_number_mask = target == self.net.numerical_token_id
         x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
 
-        # ignore index
-
-        target_mask = target != self.ignore_index
-
         # numerical mse loss
 
         numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
 
         numerical_mse_loss = numerical_mse_loss * target_mask
+        numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
 
-        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
+        # combine losses
 
-        if exists(mask):
-            loss = loss[mask]
+        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
 
+        loss = loss[target_mask]
         loss = loss.mean()
 
         if not return_loss_breakdown:
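
Note on the xval.py change above: the wrapper now builds a single target_mask from the ignore index and the key padding mask, zeroes the numerical MSE term at non-number positions, and selects the surviving per-token losses with target_mask before averaging (previously the combined loss was only indexed by the key padding mask at the end). A minimal sketch of that masking with made-up tensors follows; the ignore index value and shapes are illustrative, not the package defaults:

    import torch

    ignore_index = -100
    target = torch.tensor([[5, 7, -100, 9]])            # next-token targets
    mask   = torch.tensor([[True, True, True, False]])  # key padding mask over target positions

    per_token_loss = torch.rand(1, 4)                   # stand-in for cross entropy + weighted mse

    # positions equal to the ignore index, or padded out, contribute nothing to the mean
    target_mask = target != ignore_index
    target_mask &= mask

    loss = per_token_loss[target_mask].mean()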
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.21
+Version: 1.28.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
+x_transformers/x_transformers.py,sha256=GvqVKQZRtIldnSWX4V6qE2sWOGruRvBhk4MVit7ZD_M,63897
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=to8AkYfgvR5QC-ZXDnq8PPfzOyg0yutHt_QW36jsy98,8403
-x_transformers-1.27.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.21.dist-info/METADATA,sha256=g0OEysaFW4RomYhoqjDKu7W1brHg70m_N40_nmTLRtA,662
-x_transformers-1.27.21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.27.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.21.dist-info/RECORD,,
+x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
+x_transformers-1.28.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.0.dist-info/METADATA,sha256=o1AbarRMIJY_R0gNaEm5SNUWm3YHEesLL2EEy_Uk6gA,661
+x_transformers-1.28.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.0.dist-info/RECORD,,