x-transformers 2.3.21__py3-none-any.whl → 2.3.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +14 -13
- x_transformers/autoregressive_wrapper.py +34 -5
- x_transformers/x_transformers.py +26 -0
- {x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/METADATA +1 -1
- {x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/RECORD +7 -7
- {x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -276,21 +276,22 @@ class Attend(Module):
 
         # torch 2.3 uses new backend and context manager
 
-        if torch_version >= version.parse('2.3'):
-            from torch.nn.attention import SDPBackend
-
-            str_to_backend = dict(
-                enable_flash = SDPBackend.FLASH_ATTENTION,
-                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
-                enable_math = SDPBackend.MATH,
-                enable_cudnn = SDPBackend.CUDNN_ATTENTION
-            )
+        if self.flash:
+            if torch_version >= version.parse('2.3'):
+                from torch.nn.attention import SDPBackend
 
-            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+                str_to_backend = dict(
+                    enable_flash = SDPBackend.FLASH_ATTENTION,
+                    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                    enable_math = SDPBackend.MATH,
+                    enable_cudnn = SDPBackend.CUDNN_ATTENTION
+                )
 
-            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
-        else:
-            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
+                sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+                self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+            else:
+                self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
 
     def flash_attn(
         self,
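The practical effect of this hunk is that the SDP kernel context manager is only built when flash attention is enabled. For reference, this is roughly what the torch >= 2.3 path wires up, shown as a minimal standalone sketch (variable names and tensor sizes are illustrative, not x-transformers code):

```python
# minimal sketch of backend selection via torch.nn.attention.sdpa_kernel (torch >= 2.3)
import torch
import torch.nn.functional as F
from functools import partial
from torch.nn.attention import sdpa_kernel, SDPBackend

# analogous to the sdpa_backends list built from sdp_kwargs in Attend.__init__
backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

sdp_context_manager = partial(sdpa_kernel, backends)

q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq, dim_head)

with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
```

On torch older than 2.3 the code keeps falling back to torch.backends.cuda.sdp_kernel with the boolean enable_* flags, as before.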
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -140,7 +140,8 @@ class AutoregressiveWrapper(Module):
         ignore_index = -100,
         pad_value = 0,
         mask_prob = 0.,
-        add_attn_z_loss = False
+        add_attn_z_loss = False,
+        next_embed_loss_weight = 0.1
     ):
         super().__init__()
         self.pad_value = pad_value
@@ -156,6 +157,10 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
 
+        # whether to add a continuous loss
+        self.add_continuous_pred_head = net.add_continuous_pred_head
+        self.next_embed_loss_weight = next_embed_loss_weight
+
    @torch.no_grad()
    @eval_decorator
    def generate(
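Taken together with the TransformerWrapper change further down, the wrapper now reads add_continuous_pred_head off the wrapped network and exposes a weight for the auxiliary loss. A minimal construction sketch, assuming x-transformers 2.3.23 (hyperparameters are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

net = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    add_continuous_pred_head = True,   # new flag, adds the next-embedding head
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = AutoregressiveWrapper(
    net,
    next_embed_loss_weight = 0.1       # weight on the auxiliary embedding loss (default 0.1)
)

seq = torch.randint(0, 256, (2, 1024))
loss = wrapper(seq)                    # cross entropy + weighted next-embedding loss
loss.backward()
```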
@@ -305,9 +310,9 @@ class AutoregressiveWrapper(Module):
         return out
 
     def forward(self, x, return_outputs = False, **kwargs):
-        seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
+        seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
 
-        inp, target = x[:, :-1], x[:, 1:]
+        inp, target = x, x[:, 1:]
         inp = torch.where(inp == ignore_index, self.pad_value, inp)
 
         if self.mask_prob > 0.:
@@ -318,15 +323,29 @@
             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
             kwargs.update(self_attn_kv_mask = mask)
 
-        logits, cache = self.net(
-            inp,
+        out, cache = self.net(
+            x,
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
+            return_next_embed_pred = add_next_embed_loss,
             **kwargs
         )
 
+        # destruct differently if doing continuous pred
+
+        if add_next_embed_loss:
+            logits, (next_embed_pred, init_embeds) = out
+        else:
+            logits = out
+
+        logits = logits[:, :-1]
+
+        # loss function
+
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
 
+        # cross entropy loss
+
         loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
@@ -336,6 +355,16 @@
         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss
 
+        if add_next_embed_loss:
+            mask = inp[:, :-1] != ignore_index
+            embed_pred = next_embed_pred[:, :-1]
+            cont_targets = init_embeds[:, 1:].detach()
+
+            cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
+            cont_loss = cont_loss[mask].mean()
+
+            loss = loss + cont_loss * self.next_embed_loss_weight
+
         if not return_outputs:
             return loss
 
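The new auxiliary objective is an L1 regression from the prediction at each position onto the (detached) input embedding of the following token, masked to positions that have a real next token. A standalone sketch of just that loss term, with illustrative shapes (not the library code itself):

```python
import torch
import torch.nn.functional as F

b, n, d = 2, 1024, 512
ignore_index = -100

inp = torch.randint(0, 256, (b, n))              # token ids fed to the model
next_embed_pred = torch.randn(b, n, d)           # output of the continuous prediction head
init_embeds = torch.randn(b, n, d)               # token embeddings before the attention layers

mask = inp[:, :-1] != ignore_index               # only score positions with a valid next token
embed_pred = next_embed_pred[:, :-1]             # prediction at position t ...
cont_targets = init_embeds[:, 1:].detach()       # ... regresses onto the embedding at t + 1

cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
cont_loss = cont_loss[mask].mean()               # average over unmasked positions and channels

token_loss = torch.tensor(2.5)                   # stand-in for the usual cross entropy loss
loss = token_loss + cont_loss * 0.1              # next_embed_loss_weight defaults to 0.1
```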
x_transformers/x_transformers.py
CHANGED
@@ -2855,6 +2855,7 @@ class TransformerWrapper(Module):
         sigsoftmax_logits = False,
         ff_deep_embed = False,
         to_logits: Module | None = None,
+        add_continuous_pred_head = False
     ):
         super().__init__()
 
@@ -2975,6 +2976,18 @@
         else:
             self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
+        # add a head that predicts the embedding of the next step
+
+        self.add_continuous_pred_head = add_continuous_pred_head
+
+        if add_continuous_pred_head:
+
+            self.to_next_embed_pred = nn.Sequential(
+                LinearNoBias(dim, dim),
+                nn.SiLU(),
+                LinearNoBias(dim, dim)
+            )
+
         # memory tokens (like [cls]) from Memory Transformers paper
 
         num_memory_tokens = default(num_memory_tokens, 0)
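For readers following along without the library, the new head is just a two-layer bias-free MLP over the model dimension, applied to the final hidden states. A minimal equivalent in plain PyTorch (LinearNoBias in x-transformers is nn.Linear(..., bias = False); the dimensions here are illustrative):

```python
import torch
from torch import nn

dim = 512

# stand-in for TransformerWrapper.to_next_embed_pred when add_continuous_pred_head = True
to_next_embed_pred = nn.Sequential(
    nn.Linear(dim, dim, bias = False),
    nn.SiLU(),
    nn.Linear(dim, dim, bias = False)
)

hiddens = torch.randn(2, 128, dim)              # final hidden states (batch, seq, dim)
next_embed_pred = to_next_embed_pred(hiddens)   # predicted next-step embedding, same shape
```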
@@ -3009,6 +3022,7 @@ class TransformerWrapper(Module):
         return_intermediates = False,
         return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
+        return_next_embed_pred = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -3100,6 +3114,10 @@
             assert emb_frac_gradient > 0
             x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
 
+        # init embed
+
+        init_embed = x
+
         # embedding dropout
 
         x = self.emb_dropout(x)
@@ -3261,6 +3279,14 @@
         else:
             out = logits
 
+        # maybe next embed pred
+
+        if return_next_embed_pred:
+            assert self.add_continuous_pred_head
+            next_embed_out = self.to_next_embed_pred(x)
+
+            out = (out, (next_embed_out, init_embed))
+
         # logit entropies
 
         if return_logit_entropies:
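Putting the TransformerWrapper hunks together, passing the new return_next_embed_pred flag changes the return value from plain logits to a tuple carrying the head's prediction and the initial token embeddings. A minimal sketch of the expected call pattern based on the hunks above (shapes in the comments assume a 512-dim model and are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    add_continuous_pred_head = True,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

x = torch.randint(0, 256, (1, 128))

# requires add_continuous_pred_head = True, otherwise the assert in forward trips
logits, (next_embed_pred, init_embed) = model(x, return_next_embed_pred = True)

# logits:           (1, 128, 256) - usual token logits
# next_embed_pred:  (1, 128, 512) - predicted embedding of the next position
# init_embed:       (1, 128, 512) - token embeddings captured before the attention layers
```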
{x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
+x_transformers/autoregressive_wrapper.py,sha256=BWFaO-3YWzCcEfp-EC1ZkdckqDpPIOQG6_uyyP6AmhM,11753
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.23.dist-info/METADATA,sha256=xRMZP1TSYdcbc0F5GX-WcaHhAbQPdGeFIbjHBZYG9_0,89897
+x_transformers-2.3.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.23.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.23.dist-info/RECORD,,
{x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/WHEEL
File without changes

{x_transformers-2.3.21.dist-info → x_transformers-2.3.23.dist-info}/licenses/LICENSE
File without changes