x-transformers 2.1.27__py3-none-any.whl → 2.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +4 -2
- {x_transformers-2.1.27.dist-info → x_transformers-2.1.28.dist-info}/METADATA +1 -1
- {x_transformers-2.1.27.dist-info → x_transformers-2.1.28.dist-info}/RECORD +5 -5
- {x_transformers-2.1.27.dist-info → x_transformers-2.1.28.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.27.dist-info → x_transformers-2.1.28.dist-info}/licenses/LICENSE +0 -0
@@ -260,6 +260,7 @@ class BeliefStateWrapper(Module):
|
|
260
260
|
|
261
261
|
if exists(lens):
|
262
262
|
mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
|
263
|
+
seq_for_labels = torch.where(mask, seq, -1)
|
263
264
|
|
264
265
|
# forward autoregressive
|
265
266
|
|
@@ -319,7 +320,7 @@ class BeliefStateWrapper(Module):
|
|
319
320
|
|
320
321
|
labels_fi, labels_bi = (fi + 1), (bi - 1)
|
321
322
|
|
322
|
-
forward_labels, backward_labels =
|
323
|
+
forward_labels, backward_labels = seq_for_labels[:, labels_fi], seq_for_labels[:, labels_bi]
|
323
324
|
|
324
325
|
labels = cat((forward_labels, backward_labels), dim = -1)
|
325
326
|
|
@@ -337,7 +338,8 @@ class BeliefStateWrapper(Module):
|
|
337
338
|
loss = F.cross_entropy(
|
338
339
|
rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
|
339
340
|
labels,
|
340
|
-
reduction = 'none' if self.needs_loss_weight else 'mean'
|
341
|
+
reduction = 'none' if self.needs_loss_weight else 'mean',
|
342
|
+
ignore_index = -1
|
341
343
|
)
|
342
344
|
|
343
345
|
# maybe predict terminal
|
@@ -1,7 +1,7 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
|
2
2
|
x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
|
-
x_transformers/belief_state_wrapper.py,sha256=
|
4
|
+
x_transformers/belief_state_wrapper.py,sha256=nx7NdEZQ98Puz1RwAl7wThFJ_R8xLpUbwoqYjb6IF28,12508
|
5
5
|
x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
10
10
|
x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
|
11
11
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
12
12
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
13
|
-
x_transformers-2.1.
|
14
|
-
x_transformers-2.1.
|
15
|
-
x_transformers-2.1.
|
16
|
-
x_transformers-2.1.
|
13
|
+
x_transformers-2.1.28.dist-info/METADATA,sha256=9VQXzWtJjNhONmS9sSxM4DQrJZJok1TgkUN0q8eT-S0,87875
|
14
|
+
x_transformers-2.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
x_transformers-2.1.28.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
16
|
+
x_transformers-2.1.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|