x-transformers 2.3.18__py3-none-any.whl → 2.3.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -25,6 +25,7 @@ class Intermediates:
     values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
+    hybrid_hidden: Tensor | None = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
x_transformers/continuous.py CHANGED
@@ -32,6 +32,15 @@ def default(val, d):
         return val
     return d() if not isinstance(d, Module) and callable(d) else d
 
+def sample_from_mean_variance(
+    mean,
+    variance,
+    eps = 1e-5,
+    temperature = 1.
+):
+    std = variance.clamp(min = eps).sqrt()
+    return torch.normal(mean, std * temperature)
+
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
 
@@ -274,9 +283,7 @@ class ContinuousAutoregressiveWrapper(Module):
 
             if self.probabilistic:
                 mean, var = last_output
-                stddev = var.clamp(min = 1e-5).sqrt()
-
-                last_output = torch.normal(mean, stddev * temperature)
+                last_output = sample_from_mean_variance(mean, var, temperature = temperature)
 
             out = cat((out, last_output), dim = -2)
 
@@ -372,8 +379,7 @@ class ContinuousAutoregressiveWrapper(Module):
 
             if self.probabilistic:
                 mean, var = last_pred
-                std = var.clamp(min = 1e-5).sqrt()
-                inp = torch.normal(mean, std)
+                inp = sample_from_mean_variance(mean, var)
             else:
                 inp = last_pred
 
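The two probabilistic branches above now share a single helper rather than each clamping the variance and calling torch.normal inline. A minimal standalone sketch of what the factored-out helper computes, with illustrative tensor shapes:

import torch

def sample_from_mean_variance(mean, variance, eps = 1e-5, temperature = 1.):
    # clamp the variance for numerical stability before taking the square root
    std = variance.clamp(min = eps).sqrt()
    # temperature scales the standard deviation: > 1 widens sampling, < 1 sharpens it
    return torch.normal(mean, std * temperature)

mean = torch.zeros(2, 1, 16)                  # (batch, seq, dim) - illustrative shape
var = torch.full((2, 1, 16), 0.1)

sample = sample_from_mean_variance(mean, var, temperature = 0.8)
print(sample.shape)                           # torch.Size([2, 1, 16])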
x_transformers/x_transformers.py CHANGED
@@ -49,6 +49,7 @@ class LayerIntermediates:
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
    logit_entropies: Tensor | None = None
+    cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
 
@@ -282,12 +283,18 @@ class AbsolutePositionalEmbedding(Module):
         self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)
 
-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
 
         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset
 
         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -307,11 +314,17 @@ class ScaledSinusoidalEmbedding(Module):
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)
 
-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
 
         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset
 
         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]
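Both absolute positional embedding variants gain an offset argument so that, when only the newest tokens are fed in during cached decoding, their positions can start at the current cache length instead of zero (the RotaryEmbedding signature in the next hunk gains the same argument). A toy sketch of the idea; the class here is an illustrative stand-in, not the library module:

import torch
from torch import nn, arange

class ToyAbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, offset = 0):
        seq_len, device = x.shape[1], x.device
        # positions are shifted by however many tokens the cache already holds
        pos = arange(seq_len, device = device) + offset
        return self.emb(pos)

pos_emb = ToyAbsolutePositionalEmbedding(dim = 16, max_seq_len = 128)
new_tokens = torch.randn(1, 4, 16)            # only the 4 newest tokens
emb = pos_emb(new_tokens, offset = 60)        # cache already holds 60 tokens
print(emb.shape)                              # torch.Size([4, 16])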
@@ -676,7 +689,7 @@ class RotaryEmbedding(Module):
         return self.forward(t)
 
     @autocast('cuda', enabled = False)
-    def forward(self, t):
+    def forward(self, t, offset = 0):
         max_pos = t.max() + 1
 
         if t.ndim == 1:
@@ -1079,10 +1092,11 @@ class FoldAxially(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         if self.axial_dim == 1:
-            return self.fn(x, **kwargs)
+            return self.fn(x, *args, **kwargs)
 
         seq_len, axial_dim = x.shape[1], self.axial_dim
 
@@ -1091,7 +1105,7 @@ class FoldAxially(Module):
 
         x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
 
-        out = self.fn(x, **kwargs)
+        out = self.fn(x, *args, **kwargs)
 
         (out, *rest_out), tree_spec = tree_flatten(out)
 
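FoldAxially now forwards positional arguments to the wrapped module as well, which matters once the hybrid module in the next hunks starts receiving its previous hidden state positionally. A minimal sketch of the passthrough pattern; the wrapper and the GRU are illustrative stand-ins, not the library classes:

import torch
from torch import nn

class PassthroughWrapper(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # forward positional and keyword arguments untouched
        return self.fn(x, *args, **kwargs)

gru = nn.GRU(input_size = 8, hidden_size = 8, batch_first = True)
wrapped = PassthroughWrapper(gru)

x = torch.randn(2, 4, 8)
hidden = torch.zeros(1, 2, 8)
out, new_hidden = wrapped(x, hidden)          # the positional hidden state reaches the GRU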
@@ -1857,9 +1871,17 @@ class Attention(Module):
             if not self.causal and exists(self.hybrid_mask_kwarg):
                 hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
 
+            # handle maybe hybrid cache
+
+            hybrid_forward_args = ()
+
+            if exists(cache) and exists(cache.hybrid_hidden):
+                hybrid_hiddens = cache.hybrid_hidden
+                hybrid_forward_args = (hybrid_hiddens,)
+
             # hybrid forward
 
-            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+            hybrid_outputs = self.hybrid_module(x, *hybrid_forward_args, **hybrid_forward_kwargs)
 
             # handle hybrid out
 
@@ -1870,6 +1892,10 @@ class Attention(Module):
             if hybrid_out.ndim == 3:
                 hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
 
+            if len(rest_hybrid_outs) > 0:
+                hybrid_hidden = first(rest_hybrid_outs)
+                intermediates.hybrid_hidden = hybrid_hidden
+
             out_norm, hybrid_out_norm = self.hybrid_norms
 
             out = out_norm(out)
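The two hunks above thread a recurrent-style hidden state for the hybrid module through the cache: any extra output the hybrid module returns beyond its main output is stored on the intermediates as hybrid_hidden and handed back as a positional argument on the next cached step. A hedged sketch of that calling convention, using nn.GRU as an illustrative stand-in for the hybrid module:

import torch
from torch import nn

hybrid = nn.GRU(input_size = 16, hidden_size = 16, batch_first = True)

x_prompt = torch.randn(1, 8, 16)
out, hidden = hybrid(x_prompt)                # first call: no cached hidden state yet

x_step = torch.randn(1, 1, 16)                # a single new decoding step
out, hidden = hybrid(x_step, hidden)          # cached hidden state passed positionally, as in the hunk above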
@@ -2360,7 +2386,9 @@ class AttentionLayers(Module):
         mems = None,
         mem_masks = None,
         seq_start_pos: Tensor | None = None,
+        seq_pos_offset: int = 0,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
@@ -2434,7 +2462,7 @@ class AttentionLayers(Module):
                 mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
                 if not exists(pos):
-                    pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                    pos = arange(x.shape[1] + mem_len + seq_pos_offset, device = x.device) - mem_len
 
                 rotary_pos_emb = self.rotary_pos_emb(pos)
 
@@ -2451,11 +2479,15 @@ class AttentionLayers(Module):
 
         # assume cached key / values
 
+        prev_cache_length = 0
+
         attn_cache = []
 
         if exists(cache):
             assert self.causal and not any([*map(exists, (mask, attn_mask))])
 
+            prev_cache_length = cache.cache_length
+
             if exists(context):
                 context = context[:, :0]
 
@@ -2469,6 +2501,8 @@ class AttentionLayers(Module):
 
             attn_cache = cache.attn_intermediates
 
+        next_cache_length = x.shape[1]
+
         iter_attn_cache = iter(attn_cache)
 
         # handle deep embeds if needed
@@ -2655,6 +2689,7 @@ class AttentionLayers(Module):
             last_hidden = x,
             attn_intermediates = intermediates,
             layer_hiddens = layer_hiddens,
+            cache_length = next_cache_length + prev_cache_length
         )
 
         return x, intermediates
@@ -2989,6 +3024,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         token_emb_kwargs = dict(),
         to_logits_kwargs = dict(),
         **kwargs,
@@ -3007,10 +3043,17 @@ class TransformerWrapper(Module):
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
+        # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
+
+        seq_pos_offset = 0
+
+        if exists(cache) and input_not_include_cache:
+            seq_pos_offset = cache.cache_length
+
         # absolute positional embedding
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
-        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset) if not external_pos_emb else pos
         x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
@@ -3109,6 +3152,15 @@ class TransformerWrapper(Module):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
+        # attn layers kwargs
+
+        kwargs = dict(
+            **kwargs,
+            seq_pos_offset = seq_pos_offset,
+            seq_start_pos = seq_start_pos,
+            input_not_include_cache = input_not_include_cache
+        )
+
         # attention layers
 
         if not self.recycling:
@@ -3116,7 +3168,7 @@ class TransformerWrapper(Module):
 
             # regular
 
-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, **kwargs)
 
         else:
             # recycling
@@ -3133,7 +3185,7 @@ class TransformerWrapper(Module):
                 with context():
                     maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
 
-                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, **kwargs)
 
         x = attended
 
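Taken together, cache_length, seq_pos_offset and the new input_not_include_cache flag let the wrapper be called with only the tokens that are not yet in the cache, with positional embeddings offset by the cached length. A hedged sketch of how such a decoding loop might look; the constructor arguments follow the usual TransformerWrapper/Decoder usage, and the exact call pattern should be treated as an assumption rather than documented API:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

prompt = torch.randint(0, 256, (1, 16))

# prime the cache on the full prompt
logits, cache = model(prompt, return_intermediates = True)

next_token = logits[:, -1].argmax(dim = -1, keepdim = True)

# later steps pass only the new token; positions are offset by cache.cache_length
logits, cache = model(
    next_token,
    cache = cache,
    input_not_include_cache = True,
    return_intermediates = True
)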
x_transformers-2.3.18.dist-info/METADATA → x_transformers-2.3.20.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.18
+Version: 2.3.20
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.3.18.dist-info/RECORD → x_transformers-2.3.20.dist-info/RECORD CHANGED
@@ -1,17 +1,17 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
+x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
 x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=uV2hLQOckeRsybqJy-0F8RhAyMPJlkVHmA7QqUJHG4g,12433
+x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
+x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.18.dist-info/METADATA,sha256=RKXNlO50fifu1Nas38iZRn6IJVDkv4Cen94XYVJlWg0,89897
-x_transformers-2.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.18.dist-info/RECORD,,
+x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
+x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.20.dist-info/RECORD,,