x-transformers 1.27.21__py3-none-any.whl → 1.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -3
- x_transformers/xval.py +16 -10
- {x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/METADATA +1 -1
- {x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/RECORD +7 -7
- {x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -889,7 +889,15 @@ class Attention(nn.Module):
             else:
                 input_mask = torch.cat((mem_mask, input_mask), dim = -1)

-        if self.num_mem_kv > 0:
+        # i, j determined for relative positional bias, excluding memory key / values
+
+        i, j = map(lambda t: t.shape[-2], (q, k))
+
+        # maybe append memory key / values
+
+        has_mem_kv = self.num_mem_kv > 0
+
+        if has_mem_kv:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))

         if self.qk_norm:
@@ -902,8 +910,6 @@ class Attention(nn.Module):
         if exists(input_mask):
             input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

-        i, j = map(lambda t: t.shape[-2], (q, k))
-
         # determine masking

         mask_value = max_neg_value(q)
@@ -938,6 +944,11 @@ class Attention(nn.Module):
         if exists(rel_pos):
             attn_bias = rel_pos(i, j)

+            # append with no bias for memory key / values
+
+            if has_mem_kv:
+                attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+
         # attention is all we need

         out, intermediates = self.attend(
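Taken together, these hunks compute `i, j` from the query / key lengths before any memory key / values are concatenated, then left-pad the relative positional bias with zeros so the prepended memory slots receive no bias. The sketch below illustrates only that padding step, using plain `torch.nn.functional.pad` in place of the library's `pad_at_dim` helper; the head count, sequence lengths, and random bias tensor are illustrative assumptions, not values from the package.

```python
import torch
import torch.nn.functional as F

num_mem_kv = 2                        # assumed number of memory key / values
heads, i, j = 8, 4, 4                 # assumed head count and query / key lengths
attn_bias = torch.randn(heads, i, j)  # relative position bias, no memory positions yet

# left-pad the key dimension with zeros so the memory key / values,
# which were concatenated in front of the keys, receive no positional bias
attn_bias = F.pad(attn_bias, (num_mem_kv, 0), value = 0.)

assert attn_bias.shape == (heads, i, j + num_mem_kv)
```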
x_transformers/xval.py
CHANGED
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
         inp, target = x[:, :-1], x[:, 1:]
         x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]

+        # ignore index
+
+        target_mask = target != self.ignore_index
+
+        # key padding mask
+
         mask = kwargs.get('mask', None)
-        if exists(mask) and mask.shape[1] == x.shape[1]:
-            mask = mask[:, :-1]
-            kwargs['mask'] = mask
+        if exists(mask):
+            target_mask &= mask
+
+            if mask.shape[1] == x.shape[1]:
+                mask = mask[:, :-1]
+                kwargs['mask'] = mask

         logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)

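This first xval.py hunk builds a loss mask from `ignore_index`, folds in any user-supplied key padding mask, and trims that mask to the length the network actually sees. A rough sketch of the same idea under assumed shapes follows; the ignore index value and the tensors are made up for illustration, and the padding mask is aligned to the shifted targets before combining.

```python
import torch

ignore_index = -100                               # assumed ignore index value
x = torch.tensor([[5, 3, 7, -100]])               # token ids, last position is padding
mask = torch.tensor([[True, True, True, False]])  # key padding mask over x

inp, target = x[:, :-1], x[:, 1:]

# positions that should contribute to the loss
target_mask = target != ignore_index

# fold the key padding mask in, aligned to the shifted targets
target_mask &= mask[:, 1:]

# the network only sees the input positions, so trim the mask to match
if mask.shape[1] == x.shape[1]:
    mask = mask[:, :-1]

print(target_mask)   # tensor([[ True,  True, False]])
print(mask.shape)    # torch.Size([1, 3])
```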
@@ -276,21 +285,18 @@ class XValAutoregressiveWrapper(nn.Module):
         target_is_number_mask = target == self.net.numerical_token_id
         x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)

-        # ignore index
-
-        target_mask = target != self.ignore_index
-
         # numerical mse loss

         numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')

         numerical_mse_loss = numerical_mse_loss * target_mask
+        numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)

-
+        # combine losses

-
-        loss = loss[mask]
+        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight

+        loss = loss[target_mask]
         loss = loss.mean()

         if not return_loss_breakdown:
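The second xval.py hunk zeroes the numerical MSE term at positions whose target is not the numerical token, adds it to the cross entropy loss with a weight, and only then selects the positions allowed by `target_mask` before averaging. A small sketch of that combination with made-up per-token losses (the weight and tensors below are illustrative assumptions, not values from the package):

```python
import torch

numerical_loss_weight = 1.0                             # assumed weight on the MSE term

cross_entropy_loss = torch.tensor([[0.7, 0.2, 0.9]])    # per-token cross entropy loss
numerical_mse_loss = torch.tensor([[0.5, 0.1, 0.3]])    # per-token MSE loss
target_is_number_mask = torch.tensor([[False, True, False]])
target_mask = torch.tensor([[True, True, False]])       # ignore_index / padding mask

# only number tokens contribute an MSE term
numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)

# combine losses, then keep only positions that should be trained on
loss = cross_entropy_loss + numerical_mse_loss * numerical_loss_weight
loss = loss[target_mask]

print(loss)          # tensor([0.7000, 0.3000])
print(loss.mean())   # tensor(0.5000)
```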
{x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=GvqVKQZRtIldnSWX4V6qE2sWOGruRvBhk4MVit7ZD_M,63897
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
+x_transformers-1.28.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.0.dist-info/METADATA,sha256=o1AbarRMIJY_R0gNaEm5SNUWm3YHEesLL2EEy_Uk6gA,661
+x_transformers-1.28.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.0.dist-info/RECORD,,
{x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/LICENSE
File without changes

{x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/WHEEL
File without changes

{x_transformers-1.27.21.dist-info → x_transformers-1.28.0.dist-info}/top_level.txt
File without changes