x-transformers 1.31.4__py3-none-any.whl → 1.31.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/attend.py +8 -6
- x_transformers/x_transformers.py +3 -1
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/METADATA +1 -1
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/RECORD +7 -7
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):

 # main class

-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):

         # soft clamp attention logit value

-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.

+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value

         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias

+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype

         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):

         pre_softmax_attn = sim.clone()

-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)

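For readers skimming the diff: 1.31.4 switched the clamp on whenever `logit_softclamp_value` was set, while 1.31.5 gates it behind an explicit `softclamp_logits` flag (with the value now defaulting to `30.`) and applies it to the raw logits right after the attention bias, before masking and softmax. Below is a minimal sketch of that clamped non-flash path, assuming `softclamp` is the scaled-tanh clamp `value * tanh(t / value)`; the helper itself is not shown in this diff.

import torch

def softclamp(t, value):
    # scaled tanh: squashes logits smoothly into (-value, value) while staying differentiable
    # (assumed definition of the library's softclamp helper, which this diff only calls)
    return (t / value).tanh() * value

# toy attention logits with a few extreme values
sim = torch.randn(1, 8, 16, 16) * 100.

# mirrors Attend(softclamp_logits = True, logit_softclamp_value = 30.) in 1.31.5
softclamp_logits = True
logit_softclamp_value = 30.

if softclamp_logits:
    sim = softclamp(sim, logit_softclamp_value)

# the clamped logits are bounded, so the mask value (-finfo.max) applied afterwards still dominates
assert sim.abs().max() < logit_softclamp_value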
x_transformers/x_transformers.py
CHANGED
@@ -884,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -987,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
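On the `x_transformers.py` side the two keywords are simply threaded from `Attention.__init__` into the inner `Attend`. A hedged usage sketch follows; only the keyword names come from the diff above, while the import path, the remaining constructor defaults, and the `(output, intermediates)` return shape are assumptions about the surrounding library API rather than something this diff shows.

import torch
from x_transformers.x_transformers import Attention

# turn on logit softclamping for a single attention block;
# softclamp_logits / logit_softclamp_value are the keywords added in 1.31.5
attn = Attention(
    dim = 512,
    heads = 8,
    softclamp_logits = True,
    logit_softclamp_value = 30.
)

x = torch.randn(2, 64, 512)
out, intermediates = attn(x)   # assumed return signature: (attended tokens, attention intermediates)
print(out.shape)               # expected: torch.Size([2, 64, 512])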
{x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=VL9Dm8L5jnpgyt_V6DWGtLs9MeiTn7ZMCQdcHSFLVo8,75871
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.4.dist-info/LICENSE,sha256=
-x_transformers-1.31.4.dist-info/METADATA,sha256=
-x_transformers-1.31.4.dist-info/WHEEL,sha256=
-x_transformers-1.31.4.dist-info/top_level.txt,sha256=
-x_transformers-1.31.4.dist-info/RECORD,,
+x_transformers-1.31.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.5.dist-info/METADATA,sha256=o7z9SmDOf9BAU9FyMRuAbX4LIH8nIACA0acjGHPMz8s,661
+x_transformers-1.31.5.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.5.dist-info/RECORD,,
{x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/LICENSE
File without changes
{x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/WHEEL
File without changes
{x_transformers-1.31.4.dist-info → x_transformers-1.31.5.dist-info}/top_level.txt
File without changes