x-transformers 1.31.4__py3-none-any.whl → 1.31.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):
 
 # main class
 
-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):
 
         # soft clamp attention logit value
 
-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.
 
+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value
 
         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias
 
+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):
 
         pre_softmax_attn = sim.clone()
 
-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
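In 1.31.5 the attention-logit soft clamp in Attend becomes a pair of settings: a boolean softclamp_logits toggle and a logit_softclamp_value that now defaults to 30. The clamp is also applied earlier, immediately after the attention bias is added and before masking, rather than after pre_softmax_attn is captured. The softclamp helper itself is not part of this diff; the sketch below is a typical tanh-based soft clamp consistent with how the diff calls it, not necessarily the exact x-transformers definition.

# Hedged sketch: `softclamp` is called but not defined in this diff.
# A tanh soft clamp squashes logits smoothly into (-value, value)
# instead of hard-clipping them.
import torch

def softclamp(t, value):
    # scale down, squash with tanh, scale back up
    return (t / value).tanh() * value

sim = torch.randn(1, 8, 16, 16) * 100    # exaggerated attention logits
clamped = softclamp(sim, 30.)            # 30. is the new default logit_softclamp_value
assert clamped.abs().max().item() < 30.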
x_transformers/x_transformers.py CHANGED
@@ -884,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -987,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
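The Attention module in x_transformers.py gains the same pair of arguments and simply forwards them to Attend. Below is a minimal usage sketch built only from the keyword names visible in this diff; the import path, tensor shapes, and the forward call returning output plus intermediates are assumptions about the surrounding library rather than something this diff shows.

# Usage sketch (assumptions noted above): enable logit soft clamping on Attend.
import torch
from x_transformers.attend import Attend

attend = Attend(
    flash = False,                  # the diff asserts softclamp is not yet flash-compatible
    softclamp_logits = True,        # new boolean toggle added in 1.31.5
    logit_softclamp_value = 30.,    # clamp value, now defaulting to 30.
)

# q, k, v assumed to be (batch, heads, seq_len, dim_head)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out, intermediates = attend(q, k, v)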
x_transformers-1.31.4.dist-info/METADATA → x_transformers-1.31.5.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.4
+Version: 1.31.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.31.4.dist-info/RECORD → x_transformers-1.31.5.dist-info/RECORD RENAMED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=82gdvdwWWDWZVXTQRE6B4W-RQspg4SXhD8B6BV5pOo0,75789
+x_transformers/x_transformers.py,sha256=VL9Dm8L5jnpgyt_V6DWGtLs9MeiTn7ZMCQdcHSFLVo8,75871
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.4.dist-info/METADATA,sha256=USK7uPjCATSS4URgfekTTmPhOtvHagHC02b-R-gxwME,661
-x_transformers-1.31.4.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
-x_transformers-1.31.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.4.dist-info/RECORD,,
+x_transformers-1.31.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.5.dist-info/METADATA,sha256=o7z9SmDOf9BAU9FyMRuAbX4LIH8nIACA0acjGHPMz8s,661
+x_transformers-1.31.5.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.5.dist-info/RECORD,,