x-transformers 2.3.25__py3-none-any.whl → 2.3.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +16 -1
- x_transformers/x_transformers.py +2 -2
- {x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/METADATA +1 -1
- {x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/RECORD +6 -6
- {x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/licenses/LICENSE +0 -0
x_transformers/autoregressive_wrapper.py
CHANGED

@@ -309,7 +309,13 @@ class AutoregressiveWrapper(Module):
 
         return out
 
-    def forward(self, x, return_outputs = False, **kwargs):
+    def forward(
+        self,
+        x,
+        return_outputs = False,
+        prepend_embeds = None,
+        **kwargs
+    ):
         seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
 
         inp, target = x, x[:, 1:]
@@ -328,6 +334,7 @@ class AutoregressiveWrapper(Module):
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
             return_next_embed_pred = add_next_embed_loss,
+            prepend_embeds = prepend_embeds,
             **kwargs
         )
 
@@ -338,6 +345,14 @@ class AutoregressiveWrapper(Module):
         else:
             logits = out
 
+        # if there are prepended embeds, excise it out
+
+        if exists(prepend_embeds):
+            prepend_len = prepend_embeds.shape[1]
+            logits = logits[:, prepend_len:]
+
+        # take all tokens but the last
+
         logits = logits[:, :-1]
 
         # loss function
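Taken together, the three hunks above let a prefix of continuous embeddings ride along during training: `forward` now accepts `prepend_embeds`, hands it through to the underlying network, and slices the prefix positions off the logits so the loss still aligns token-for-token with the targets. Below is a minimal usage sketch of that flow; the vocab size, sequence lengths, and model dimensions are illustrative assumptions, not values taken from the package.

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# toy model; dims chosen only for the example
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

wrapper = AutoregressiveWrapper(model)

tokens = torch.randint(0, 256, (2, 128))   # (batch, seq) target token ids
prefix = torch.randn(2, 4, 512)            # (batch, prefix_len, dim) continuous conditioning

# new in this diff: prepend_embeds is forwarded to the transformer, and the
# logits for the 4 prefix positions are excised before the loss is computed
loss = wrapper(tokens, prepend_embeds = prefix)
loss.backward()

Note that the excision happens before the usual `logits = logits[:, :-1]` shift, so the prepended positions never contribute to the cross entropy loss.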
x_transformers/x_transformers.py
CHANGED

@@ -1926,7 +1926,7 @@ class Attention(Module):
 
         out = maybe(self.sublayer_dropout)(out)
 
-        if exists(mask):
+        if exists(mask) and not exists(cache):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
@@ -2484,7 +2484,7 @@ class AttentionLayers(Module):
         attn_cache = []
 
         if exists(cache):
-            assert self.causal and not
+            assert self.causal and not exists(attn_mask)
 
             prev_cache_length = cache.cache_length
 
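These two changes work in tandem for decoding with a key/value cache: the post-attention zero-masking in `Attention` is skipped once a cache exists (the padding mask covers the full sequence while `out` covers only the newly decoded positions), and the `AttentionLayers` assert is relaxed so only an explicit `attn_mask` remains disallowed alongside `cache`. The sketch below shows the decoding pattern this appears to enable, assuming the usual `return_intermediates = True` cache flow used by `AutoregressiveWrapper.generate`; the model and shapes are again illustrative.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
).eval()

ids  = torch.randint(0, 256, (2, 16))
mask = torch.ones(2, 16, dtype = torch.bool)
mask[1, :5] = False                      # e.g. left padding on the second sequence

with torch.no_grad():
    # prime the cache on the first pass
    logits, cache = model(ids, mask = mask, return_intermediates = True)

    # one incremental step: append the greedy token, extend the mask, reuse the cache;
    # passing mask together with cache is the combination these hunks permit
    next_ids = logits[:, -1].argmax(dim = -1, keepdim = True)
    ids  = torch.cat((ids, next_ids), dim = -1)
    mask = torch.cat((mask, torch.ones(2, 1, dtype = torch.bool)), dim = -1)

    logits, cache = model(ids, mask = mask, cache = cache, return_intermediates = True)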
{x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=3tDUiY5kNcxNUjRERoeuFV0mXztOvgGrckoACIfHvqI,12091
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=g7y9U48sirVN6oFq_XxPUDhqKO0U8pdmLYcbT0CoH1E,116223
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.25.dist-info/METADATA,sha256=
-x_transformers-2.3.25.dist-info/WHEEL,sha256=
-x_transformers-2.3.25.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.3.25.dist-info/RECORD,,
+x_transformers-2.3.27.dist-info/METADATA,sha256=UNVupcXx-VDnWW5sRWJ4WlxOvUtwDDAy0Lig6s5xG0I,89897
+x_transformers-2.3.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.27.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.27.dist-info/RECORD,,
{x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/WHEEL
File without changes

{x_transformers-2.3.25.dist-info → x_transformers-2.3.27.dist-info}/licenses/LICENSE
File without changes