x-transformers 1.30.2__py3-none-any.whl → 1.30.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -37,6 +37,9 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def once(fn):
     called = False
     @wraps(fn)
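
The added softclamp helper squashes its input through tanh, so every logit is smoothly bounded to the open interval (-value, value) while logits much smaller than the clamp value pass through nearly unchanged. A minimal standalone sketch of that behaviour (illustration only, not library code):

import torch

def softclamp(t, value):
    # mirror of the helper added above: output bounded to (-value, value), ~identity near zero
    return (t / value).tanh() * value

t = torch.tensor([-1000., -1., 0., 1., 1000.])
print(softclamp(t, 50.))  # ≈ [-50., -1., 0., 1., 50.]: extremes saturate at ±50, small logits barely move
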
@@ -76,6 +79,7 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
+        logit_softclamp_value = None,
         add_zero_kv = False,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -114,6 +118,14 @@ class Attend(nn.Module):
 
         self.add_zero_kv = add_zero_kv
 
+        # soft clamp attention logit value
+
+        if exists(logit_softclamp_value):
+            assert not flash, 'flash attention not compatible with logit softclamp value yet'
+            assert logit_softclamp_value > 0.
+
+        self.logit_softclamp_value = logit_softclamp_value
+
         # flash attention
 
         self.flash = flash
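
With the constructor change above, Attend validates and stores the clamp value; the assert means flash attention must stay off when a clamp value is given, for now. A hypothetical construction using only the kwargs visible in this diff (the value 50. is illustrative, not a recommendation):

from x_transformers.attend import Attend

attend = Attend(
    flash = False,                # the assert above rejects flash together with softclamp, so keep it off
    logit_softclamp_value = 50.   # must be > 0; attention logits are bounded to (-50, 50) before softmax
)
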
@@ -276,38 +288,41 @@ class Attend(nn.Module):
 
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
-        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
 
         if exists(prev_attn):
-            dots = dots + prev_attn
+            sim = sim + prev_attn
 
-        qk_similarities = dots.clone()
+        qk_similarities = sim.clone()
 
         if self.talking_heads:
-            dots = self.pre_softmax_talking_heads(dots)
+            sim = self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
-            dots = dots + attn_bias
+            sim = sim + attn_bias
 
-        i, j, dtype = *dots.shape[-2:], dots.dtype
+        i, j, dtype = *sim.shape[-2:], sim.dtype
 
-        mask_value = -torch.finfo(dots.dtype).max
+        mask_value = -torch.finfo(sim.dtype).max
 
         if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = dots.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = dots < top_values[..., -1:]
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = sim < top_values[..., -1:]
             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
 
         if exists(mask):
-            dots = dots.masked_fill(~mask, mask_value)
+            sim = sim.masked_fill(~mask, mask_value)
 
         if causal:
             causal_mask = self.create_causal_mask(i, j, device = device)
-            dots = dots.masked_fill(causal_mask, mask_value)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+        pre_softmax_attn = sim.clone()
 
-        pre_softmax_attn = dots.clone()
+        if exists(self.logit_softclamp_value):
+            sim = softclamp(sim, self.logit_softclamp_value)
 
-        attn = self.attn_fn(dots, dim = -1)
+        attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn.clone()
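
In the non-flash path above, the clamp is applied after masking and after the pre_softmax_attn snapshot, immediately before the softmax. Masked positions, filled with -finfo.max, therefore saturate to roughly -value; for typical clamp values that still leaves them a vanishing share of the softmax weight. A rough sketch of that ordering with dummy tensors (assumed shapes, not library code):

import torch

def softclamp(t, value):
    return (t / value).tanh() * value

sim = torch.randn(1, 1, 4, 4) * 100.                           # unnormalized logits, some far outside the clamp range
mask_value = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(torch.eye(4, dtype = torch.bool), mask_value)

sim = softclamp(sim, 30.)                                      # all logits now lie in (-30, 30); masked entries sit near -30
attn = sim.softmax(dim = -1)                                   # masked entries retain ~exp(-60) of the top weight, effectively zero
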
x_transformers/x_transformers.py CHANGED
@@ -722,6 +722,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        logit_softclamp_value = None,
         onnxable = False
     ):
         super().__init__()
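
At the user level, this new Attention kwarg can presumably be reached through the usual attn_-prefixed kwarg routing of the attention layer classes; a hypothetical end-to-end sketch (all hyperparameters illustrative):

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_logit_softclamp_value = 30.  # assumed routing: forwarded as logit_softclamp_value to each Attention, then to Attend
    )
)
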
@@ -801,6 +802,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            logit_softclamp_value = logit_softclamp_value,
             onnxable = onnxable
         )
 
@@ -866,7 +868,7 @@ class Attention(Module):
 
         k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
 
-        if exists(cache) and not has_context:
+        if exists(cache):
             ck, cv = cache.cached_kv
 
             if exists(mem):
@@ -1336,6 +1338,9 @@ class AttentionLayers(Module):
         if exists(cache):
             assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
 
+            if exists(context):
+                context = context[:, :0]
+
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1
 
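
Taken together with the `if exists(cache):` change above, this truncates the cross-attention context to zero length once a cache is present, so the cross-attention keys and values come from cache.cached_kv instead of being re-projected on every decoding step. A small illustration of what the slice itself does (shapes are assumed for the example):

import torch

context = torch.randn(2, 128, 512)   # e.g. encoder output: (batch, source length, dim)
empty = context[:, :0]               # zero length along the sequence axis
print(empty.shape)                   # torch.Size([2, 0, 512])
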
x_transformers-1.30.2.dist-info/METADATA → x_transformers-1.30.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.2
+Version: 1.30.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.2.dist-info/RECORD → x_transformers-1.30.4.dist-info/RECORD RENAMED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+x_transformers/attend.py,sha256=2SPHjXS_QAAZt04lHWGtdOypTExmo3BrbFhgcIQTk-Y,10671
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=uwooxffSXL2vTxLhDnkxF7fMe0gaCFW5WinuiR0fQpU,66191
+x_transformers/x_transformers.py,sha256=P4rqlYGS9j9Gz00B4NPM7L6mhvamSYdBy5nG0ggOIMM,66342
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.2.dist-info/METADATA,sha256=ih2I-SzJQe_qASq_WzOwNuGKKrNTe5mvTba0ZrnMdfI,661
-x_transformers-1.30.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.2.dist-info/RECORD,,
+x_transformers-1.30.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.4.dist-info/METADATA,sha256=VwdrJaRjocQXIAxdGzq4rByPGvaA4jsogostzCysdjI,661
+x_transformers-1.30.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.4.dist-info/RECORD,,