x-transformers 2.3.21__py3-none-any.whl → 2.3.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -276,21 +276,22 @@ class Attend(Module):

         # torch 2.3 uses new backend and context manager

-        if torch_version >= version.parse('2.3'):
-            from torch.nn.attention import SDPBackend
-
-            str_to_backend = dict(
-                enable_flash = SDPBackend.FLASH_ATTENTION,
-                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
-                enable_math = SDPBackend.MATH,
-                enable_cudnn = SDPBackend.CUDNN_ATTENTION
-            )
+        if self.flash:
+            if torch_version >= version.parse('2.3'):
+                from torch.nn.attention import SDPBackend

-            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+                str_to_backend = dict(
+                    enable_flash = SDPBackend.FLASH_ATTENTION,
+                    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                    enable_math = SDPBackend.MATH,
+                    enable_cudnn = SDPBackend.CUDNN_ATTENTION
+                )

-            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
-        else:
-            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
+                sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+                self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+            else:
+                self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

     def flash_attn(
         self,
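In the hunk above, the SDPA backend setup is now guarded by `self.flash`, so the kernel context manager is only built when flash attention is actually enabled; the PyTorch >= 2.3 path still maps the old `sdp_kwargs` flag names onto the `SDPBackend` enum. A minimal, standalone sketch of that backend-selection pattern follows (it assumes a recent PyTorch that exposes `torch.nn.attention.sdpa_kernel` and all four `SDPBackend` members; the example flags and tensors are illustrative, not part of x-transformers):

```python
# Sketch of the backend-selection pattern used in the diff above.
from functools import partial

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# map the legacy sdp_kwargs flag names onto the SDPBackend enum
str_to_backend = dict(
    enable_flash = SDPBackend.FLASH_ATTENTION,
    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
    enable_math = SDPBackend.MATH,
    enable_cudnn = SDPBackend.CUDNN_ATTENTION
)

# illustrative flags, as they might arrive in sdp_kwargs
sdp_kwargs = dict(enable_flash = True, enable_mem_efficient = True, enable_math = True, enable_cudnn = True)

sdpa_backends = [str_to_backend[name] for name, enabled in sdp_kwargs.items() if enabled]
sdp_context_manager = partial(sdpa_kernel, sdpa_backends)

q = k = v = torch.randn(1, 8, 16, 64)

# attention inside the context manager is restricted to the allowed backends
with sdp_context_manager():
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```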
x_transformers/autoregressive_wrapper.py CHANGED
@@ -140,7 +140,8 @@ class AutoregressiveWrapper(Module):
         ignore_index = -100,
         pad_value = 0,
         mask_prob = 0.,
-        add_attn_z_loss = False
+        add_attn_z_loss = False,
+        next_embed_loss_weight = 0.1
     ):
         super().__init__()
         self.pad_value = pad_value
@@ -156,6 +157,10 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss

+        # whether to add a continuous loss
+        self.add_continuous_pred_head = net.add_continuous_pred_head
+        self.next_embed_loss_weight = next_embed_loss_weight
+
     @torch.no_grad()
     @eval_decorator
     def generate(
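The wrapper picks up `add_continuous_pred_head` from the wrapped network and adds a `next_embed_loss_weight` knob for the new auxiliary loss. Below is a sketch of how the two new arguments fit together; the surrounding `TransformerWrapper` / `Decoder` construction follows the library's usual README-style usage and is assumed rather than taken from this diff:

```python
# Sketch of wiring the new options together. The two highlighted kwargs
# (add_continuous_pred_head, next_embed_loss_weight) come from this diff;
# everything else is the usual x-transformers setup and is assumed.
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    add_continuous_pred_head = True       # new: extra head predicting the next position's embedding
)

wrapper = AutoregressiveWrapper(
    model,
    next_embed_loss_weight = 0.1          # new: weight on the auxiliary L1 embedding loss
)

seq = torch.randint(0, 256, (2, 1024))
loss = wrapper(seq)                       # cross entropy + weighted continuous loss
loss.backward()
```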
@@ -305,9 +310,9 @@ class AutoregressiveWrapper(Module):
         return out

     def forward(self, x, return_outputs = False, **kwargs):
-        seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
+        seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head

-        inp, target = x[:, :-1], x[:, 1:]
+        inp, target = x, x[:, 1:]
         inp = torch.where(inp == ignore_index, self.pad_value, inp)

         if self.mask_prob > 0.:
@@ -318,15 +323,29 @@ class AutoregressiveWrapper(Module):
             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
             kwargs.update(self_attn_kv_mask = mask)

-        logits, cache = self.net(
-            inp,
+        out, cache = self.net(
+            x,
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
+            return_next_embed_pred = add_next_embed_loss,
             **kwargs
         )

+        # destruct differently if doing continuous pred
+
+        if add_next_embed_loss:
+            logits, (next_embed_pred, init_embeds) = out
+        else:
+            logits = out
+
+        logits = logits[:, :-1]
+
+        # loss function
+
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

+        # cross entropy loss
+
         loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
@@ -336,6 +355,16 @@ class AutoregressiveWrapper(Module):
         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss

+        if add_next_embed_loss:
+            mask = inp[:, :-1] != ignore_index
+            embed_pred = next_embed_pred[:, :-1]
+            cont_targets = init_embeds[:, 1:].detach()
+
+            cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
+            cont_loss = cont_loss[mask].mean()
+
+            loss = loss + cont_loss * self.next_embed_loss_weight
+
         if not return_outputs:
             return loss

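The auxiliary term added here is an L1 regression from the head's prediction at position t onto the (detached) initial embedding of the token at position t + 1, averaged only over positions whose input token is not the ignore index. A standalone sketch of that masked-L1 computation, with illustrative shapes and the same variable names as the diff:

```python
# Standalone sketch of the masked L1 next-embedding loss added above.
import torch
import torch.nn.functional as F

b, n, d, ignore_index = 2, 8, 16, -100

inp = torch.randint(0, 10, (b, n))              # token ids fed to the net
next_embed_pred = torch.randn(b, n, d)          # head output: prediction of the next step's embedding
init_embeds = torch.randn(b, n, d)              # token embeddings before dropout / attention

mask = inp[:, :-1] != ignore_index              # only score positions holding a real token
embed_pred = next_embed_pred[:, :-1]            # prediction at step t ...
cont_targets = init_embeds[:, 1:].detach()      # ... regresses onto the embedding at step t + 1

cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
cont_loss = cont_loss[mask].mean()              # average over unmasked positions and the feature dim
```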
x_transformers/x_transformers.py CHANGED
@@ -2855,6 +2855,7 @@ class TransformerWrapper(Module):
         sigsoftmax_logits = False,
         ff_deep_embed = False,
         to_logits: Module | None = None,
+        add_continuous_pred_head = False
     ):
         super().__init__()

@@ -2975,6 +2976,18 @@ class TransformerWrapper(Module):
         else:
             self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits

+        # add a head that predicts the embedding of the next step
+
+        self.add_continuous_pred_head = add_continuous_pred_head
+
+        if add_continuous_pred_head:
+
+            self.to_next_embed_pred = nn.Sequential(
+                LinearNoBias(dim, dim),
+                nn.SiLU(),
+                LinearNoBias(dim, dim)
+            )
+
         # memory tokens (like [cls]) from Memory Transformers paper

         num_memory_tokens = default(num_memory_tokens, 0)
@@ -3009,6 +3022,7 @@ class TransformerWrapper(Module):
         return_intermediates = False,
         return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
+        return_next_embed_pred = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -3100,6 +3114,10 @@ class TransformerWrapper(Module):
             assert emb_frac_gradient > 0
             x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

+        # init embed
+
+        init_embed = x
+
         # embedding dropout

         x = self.emb_dropout(x)
@@ -3261,6 +3279,14 @@ class TransformerWrapper(Module):
         else:
             out = logits

+        # maybe next embed pred
+
+        if return_next_embed_pred:
+            assert self.add_continuous_pred_head
+            next_embed_out = self.to_next_embed_pred(x)
+
+            out = (out, (next_embed_out, init_embed))
+
         # logit entropies

         if return_logit_entropies:
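When `return_next_embed_pred` is set, the logits are wrapped in a tuple together with the embedding prediction and the saved initial embeddings, which is exactly how the AutoregressiveWrapper destructures the result in the hunks above. A sketch of calling the model directly with the new flag; model construction again follows the usual README-style usage and is assumed:

```python
# Sketch of the new return nesting: with return_intermediates = True and
# return_next_embed_pred = True the call yields (out, intermediates) where
# out = (logits, (next_embed_pred, init_embed)), matching the wrapper's destructuring.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8),
    add_continuous_pred_head = True
)

x = torch.randint(0, 256, (1, 128))

out, intermediates = model(x, return_intermediates = True, return_next_embed_pred = True)
logits, (next_embed_pred, init_embed) = out

print(logits.shape)           # (1, 128, 256)
print(next_embed_pred.shape)  # (1, 128, 512) - prediction of the next position's embedding
print(init_embed.shape)       # (1, 128, 512) - embeddings before attention, used as the regression target
```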
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.21
+Version: 2.3.23
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
-x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
+x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
+x_transformers/autoregressive_wrapper.py,sha256=BWFaO-3YWzCcEfp-EC1ZkdckqDpPIOQG6_uyyP6AmhM,11753
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
+x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.21.dist-info/METADATA,sha256=530_RGFGFlDyKIV6vMGqjGGw0f3gpArBbwNBHai_LQs,89897
-x_transformers-2.3.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.21.dist-info/RECORD,,
+x_transformers-2.3.23.dist-info/METADATA,sha256=xRMZP1TSYdcbc0F5GX-WcaHhAbQPdGeFIbjHBZYG9_0,89897
+x_transformers-2.3.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.23.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.23.dist-info/RECORD,,