x-transformers 1.30.1__py3-none-any.whl → 1.30.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +28 -13
- x_transformers/x_transformers.py +7 -2
- {x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/METADATA +1 -1
- {x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/RECORD +7 -7
- {x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -37,6 +37,9 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def once(fn):
     called = False
     @wraps(fn)
@@ -76,6 +79,7 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
+        logit_softclamp_value = None,
         add_zero_kv = False,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -114,6 +118,14 @@ class Attend(nn.Module):
 
         self.add_zero_kv = add_zero_kv
 
+        # soft clamp attention logit value
+
+        if exists(logit_softclamp_value):
+            assert not flash, 'flash attention not compatible with logit softclamp value yet'
+            assert logit_softclamp_value > 0.
+
+        self.logit_softclamp_value = logit_softclamp_value
+
         # flash attention
 
         self.flash = flash
@@ -276,38 +288,41 @@ class Attend(nn.Module):
 
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
-
+        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
 
         if exists(prev_attn):
-
+            sim = sim + prev_attn
 
-        qk_similarities =
+        qk_similarities = sim.clone()
 
         if self.talking_heads:
-
+            sim = self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
-
+            sim = sim + attn_bias
 
-        i, j, dtype = *
+        i, j, dtype = *sim.shape[-2:], sim.dtype
 
-        mask_value = -torch.finfo(
+        mask_value = -torch.finfo(sim.dtype).max
 
         if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ =
-            sparse_topk_mask =
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = sim < top_values[..., -1:]
             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
 
         if exists(mask):
-
+            sim = sim.masked_fill(~mask, mask_value)
 
         if causal:
             causal_mask = self.create_causal_mask(i, j, device = device)
-
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+        pre_softmax_attn = sim.clone()
 
-
+        if exists(self.logit_softclamp_value):
+            sim = softclamp(sim, self.logit_softclamp_value)
 
-        attn = self.attn_fn(
+        attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn.clone()
x_transformers/x_transformers.py
CHANGED
@@ -468,7 +468,8 @@ def rotate_half(x):
 
 @autocast(enabled = False)
 def apply_rotary_pos_emb(t, freqs, scale = 1):
-    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+
     freqs = freqs[-seq_len:, :]
     scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
@@ -478,7 +479,9 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-
+    out = torch.cat((t, t_unrotated), dim = -1)
+
+    return out.type(orig_dtype)
 
 # norms
 
@@ -719,6 +722,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        logit_softclamp_value = None,
         onnxable = False
     ):
         super().__init__()
@@ -798,6 +802,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            logit_softclamp_value = logit_softclamp_value,
             onnxable = onnxable
         )
 
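On the x_transformers.py side, `Attention` gains the matching `logit_softclamp_value` keyword and simply forwards it to `Attend`, while `apply_rotary_pos_emb` now records the input dtype and casts the rotated output back to it, so a half-precision tensor is no longer promoted to float32 by the rotary frequencies. A rough sketch of how the two changes surface to a caller; the tensor shapes, the clamp value of `30.`, and the module sizes are illustrative assumptions:

```python
import torch
from x_transformers.x_transformers import Attention, apply_rotary_pos_emb

# the new keyword is forwarded from Attention down to Attend (flash defaults to off, as the assert requires)
attn = Attention(dim = 512, heads = 8, logit_softclamp_value = 30.)

# dtype check for the rotary fix: (batch, heads, seq, dim_head) query in half precision,
# (seq, rot_dim) float32 frequencies; the shapes are arbitrary illustrations
t = torch.randn(1, 8, 16, 64, dtype = torch.float16)
freqs = torch.randn(16, 64)

out = apply_rotary_pos_emb(t, freqs)
assert out.shape == t.shape
assert out.dtype == t.dtype  # cast back to the input dtype by the new `return out.type(orig_dtype)`
```

Combining the soft clamp with flash attention is still rejected by the assert added in attend.py, so the keyword is only meaningful on the einsum path.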
{x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=2SPHjXS_QAAZt04lHWGtdOypTExmo3BrbFhgcIQTk-Y,10671
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=BQypGJAoqXrAe_ek95wUcXdSQdAWjvw5mEti-H1JxcI,66288
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.3.dist-info/METADATA,sha256=FdVMtNhhggibbG3fJfCfGdRUAp6fLkTyG8KzDUq_r1Y,661
+x_transformers-1.30.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.3.dist-info/RECORD,,
{x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/LICENSE
File without changes
{x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/WHEEL
File without changes
{x_transformers-1.30.1.dist-info → x_transformers-1.30.3.dist-info}/top_level.txt
File without changes
|