x-transformers 1.30.2__py3-none-any.whl → 1.30.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -37,6 +37,9 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def once(fn):
     called = False
     @wraps(fn)
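Note: the new softclamp helper bounds its input smoothly within (-value, value) via a scaled tanh rather than a hard clip. A minimal standalone sketch (not part of the package; the tensor values are illustrative only):

    import torch

    def softclamp(t, value):
        # same form as the helper added above: squash logits into (-value, value)
        return (t / value).tanh() * value

    logits = torch.tensor([-100., -5., 0., 5., 100.])
    print(softclamp(logits, 50.))
    # ≈ tensor([-48.2014, -4.9834, 0.0000, 4.9834, 48.2014])
    # large magnitudes saturate near +/-50, small ones pass through nearly unchanged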
@@ -76,6 +79,7 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
+        logit_softclamp_value = None,
         add_zero_kv = False,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -114,6 +118,14 @@ class Attend(nn.Module):
 
         self.add_zero_kv = add_zero_kv
 
+        # soft clamp attention logit value
+
+        if exists(logit_softclamp_value):
+            assert not flash, 'flash attention not compatible with logit softclamp value yet'
+            assert logit_softclamp_value > 0.
+
+        self.logit_softclamp_value = logit_softclamp_value
+
         # flash attention
 
         self.flash = flash
@@ -276,38 +288,41 @@ class Attend(nn.Module):
 
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
-        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
 
         if exists(prev_attn):
-            dots = dots + prev_attn
+            sim = sim + prev_attn
 
-        qk_similarities = dots.clone()
+        qk_similarities = sim.clone()
 
         if self.talking_heads:
-            dots = self.pre_softmax_talking_heads(dots)
+            sim = self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
-            dots = dots + attn_bias
+            sim = sim + attn_bias
 
-        i, j, dtype = *dots.shape[-2:], dots.dtype
+        i, j, dtype = *sim.shape[-2:], sim.dtype
 
-        mask_value = -torch.finfo(dots.dtype).max
+        mask_value = -torch.finfo(sim.dtype).max
 
         if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = dots.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = dots < top_values[..., -1:]
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = sim < top_values[..., -1:]
             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
 
         if exists(mask):
-            dots = dots.masked_fill(~mask, mask_value)
+            sim = sim.masked_fill(~mask, mask_value)
 
         if causal:
             causal_mask = self.create_causal_mask(i, j, device = device)
-            dots = dots.masked_fill(causal_mask, mask_value)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+        pre_softmax_attn = sim.clone()
 
-        pre_softmax_attn = dots.clone()
+        if exists(self.logit_softclamp_value):
+            sim = softclamp(sim, self.logit_softclamp_value)
 
-        attn = self.attn_fn(dots, dim = -1)
+        attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn.clone()
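Note: in the non-flash path above, the clamp is applied to the fully masked and biased similarity matrix immediately before the softmax. A condensed sketch of that ordering under assumed shapes, omitting masking, talking heads and sparse top-k:

    import torch

    def softclamp(t, value):
        return (t / value).tanh() * value

    b, h, n, d = 1, 8, 16, 64                       # illustrative batch, heads, sequence length, head dim
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    sim = torch.einsum('bhid,bhjd->bhij', q, k) * d ** -0.5
    # (attention bias and masking would be applied to sim here)
    sim = softclamp(sim, 30.)                       # 30. is an arbitrary example clamp value
    attn = sim.softmax(dim = -1)
    out = torch.einsum('bhij,bhjd->bhid', attn, v)  # shape (1, 8, 16, 64)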
x_transformers/x_transformers.py CHANGED
@@ -722,6 +722,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        logit_softclamp_value = None,
         onnxable = False
     ):
         super().__init__()
@@ -801,6 +802,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            logit_softclamp_value = logit_softclamp_value,
             onnxable = onnxable
         )
 
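Note: at the user-facing level the new option is threaded from Attention down into Attend. A hedged usage sketch, assuming the library's usual routing of attn_-prefixed keyword arguments through the attention layers (all hyperparameters below are illustrative only):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_logit_softclamp_value = 30.   # assumed attn_ prefix; must be > 0, with flash left at its default of False
        )
    )

    x = torch.randint(0, 256, (1, 1024))
    logits = model(x)                          # shape (1, 1024, 256)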
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.2
+Version: 1.30.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+x_transformers/attend.py,sha256=2SPHjXS_QAAZt04lHWGtdOypTExmo3BrbFhgcIQTk-Y,10671
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=uwooxffSXL2vTxLhDnkxF7fMe0gaCFW5WinuiR0fQpU,66191
+x_transformers/x_transformers.py,sha256=BQypGJAoqXrAe_ek95wUcXdSQdAWjvw5mEti-H1JxcI,66288
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.2.dist-info/METADATA,sha256=ih2I-SzJQe_qASq_WzOwNuGKKrNTe5mvTba0ZrnMdfI,661
-x_transformers-1.30.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.2.dist-info/RECORD,,
+x_transformers-1.30.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.3.dist-info/METADATA,sha256=FdVMtNhhggibbG3fJfCfGdRUAp6fLkTyG8KzDUq_r1Y,661
+x_transformers-1.30.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.3.dist-info/RECORD,,