x-transformers 1.31.4__py3-none-any.whl → 1.31.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +8 -6
- x_transformers/x_transformers.py +7 -3
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/METADATA +1 -1
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/RECORD +7 -7
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/WHEEL +0 -0
- {x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED

@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):
 
 # main class
 
-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):
 
         # soft clamp attention logit value
 
-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.
 
+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value
 
         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias
 
+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):
 
         pre_softmax_attn = sim.clone()
 
-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
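In short, the change replaces the implicit exists(logit_softclamp_value) gate with an explicit softclamp_logits flag and moves the clamp so it is applied to the raw similarity logits before masking and softmax, rather than after the pre-softmax clone. For reference, below is a minimal sketch of a tanh-based soft clamp of the kind this diff refers to (the soft capping popularized by Gemma 2); the helper name softclamp matches the diff, but the body shown here is an assumption, not a copy of the package source.

import torch

def softclamp(t, value):
    # assumed implementation: scale into tanh, then scale back out,
    # so outputs are smoothly bounded to (-value, value)
    return (t / value).tanh() * value

# toy attention logits with a couple of extreme values
sim = torch.tensor([-100., -5., 0., 5., 100.])

print(softclamp(sim, 30.))   # extremes squashed toward +/-30, gradients stay nonzero
print(sim.clamp(-30., 30.))  # hard clamp for comparison, gradient is zero past the cap

Note that, per the assertion kept in the diff, the logit soft clamp is only supported on the non-flash attention path for now.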
x_transformers/x_transformers.py
CHANGED

@@ -884,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -987,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
@@ -1265,7 +1267,8 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
-        softclamp_output_value = None,
+        softclamp_output = False,
+        softclamp_output_value = 50.,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,
@@ -1482,6 +1485,7 @@ class AttentionLayers(Module):
         # optional soft clamping just before the final norm
         # used in gemma 2
 
+        self.softclamp_output = softclamp_output
         self.softclamp_output_value = softclamp_output_value
 
         # whether it has post norm
@@ -1715,7 +1719,7 @@ class AttentionLayers(Module):
            if return_hiddens:
                layer_hiddens.append(x)
 
-        if exists(self.softclamp_output_value):
+        if self.softclamp_output:
            x = softclamp(x, self.softclamp_output_value)
 
        final_norm = self.final_norm
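These hunks thread the new softclamp_logits flag from Attention down into Attend, and add an analogous softclamp_output flag to AttentionLayers that soft-clamps the residual stream just before the final norm (as the comment notes, the Gemma 2 recipe). The configuration sketch below shows how the new flags might be enabled; it assumes x-transformers' usual convention that attn_-prefixed keyword arguments on a Decoder/Encoder are forwarded to Attention, so the exact kwarg routing is an assumption rather than something shown in this diff.

import torch
from x_transformers import TransformerWrapper, Decoder

# hedged configuration sketch; attn_ prefix routing is assumed, not shown in the diff
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_softclamp_logits = True,       # forwarded to Attention/Attend as softclamp_logits
        attn_logit_softclamp_value = 30.,   # default cap value per the diff
        softclamp_output = True,            # AttentionLayers: clamp residual stream before final norm
        softclamp_output_value = 50.        # default cap value per the diff
    )
)

tokens = torch.randint(0, 20000, (1, 256))
logits = model(tokens)  # (1, 256, 20000)

The two clamps are independent: softclamp_logits bounds the per-head attention logits inside Attend, while softclamp_output bounds the hidden states once, after the last layer.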
{x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/RECORD
CHANGED

@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=xPwzR3bd8BS_ChEcz0UxsNtx99u4UbP8jg1fFIRDGUw,75925
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
+x_transformers-1.31.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.6.dist-info/METADATA,sha256=O2MZXNuX-jqrAdpcxkIDH4J0T63t5nt9utI8FHrRIA0,661
+x_transformers-1.31.6.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.6.dist-info/RECORD,,
{x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/LICENSE
File without changes

{x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/WHEEL
File without changes

{x_transformers-1.31.4.dist-info → x_transformers-1.31.6.dist-info}/top_level.txt
File without changes