x-transformers 1.23.0__py3-none-any.whl → 1.23.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +0 -76
- x_transformers/x_transformers.py +16 -8
- {x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/METADATA +1 -1
- {x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/RECORD +7 -7
- {x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -346,79 +346,3 @@ class Attend(nn.Module):
         )

         return out, intermediates
-
-# cascading heads logic
-
-def to_single_heads(t, dim = 1):
-    heads = t.unbind(dim = dim)
-    return tuple(head.unsqueeze(dim) for head in heads)
-
-class CascadingHeads(nn.Module):
-    def __init__(self, attend: Attend):
-        super().__init__()
-        self.attend = attend
-
-    def forward(
-        self,
-        q, k, v,
-        mask = None,
-        attn_bias = None,
-        prev_attn = None
-    ):
-        assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same'
-
-        # split inputs into per-head inputs
-
-        heads = q.shape[1]
-
-        queries = to_single_heads(q)
-        keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
-        values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)
-
-        mask = (mask,) * heads
-
-        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
-        prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
-
-        # now loop through each head, without output of previous head summed with the next head
-        # thus cascading
-
-        all_outs = []
-        all_intermediates = []
-
-        prev_head_out = None
-
-        for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn):
-
-            if exists(prev_head_out):
-                h_q = h_q + prev_head_out
-
-            out, intermediates = self.attend(
-                h_q, h_k, h_v,
-                mask = h_mask,
-                attn_bias = h_attn_bias,
-                prev_attn = h_prev_attn
-            )
-
-            prev_head_out = out
-
-            all_outs.append(out)
-            all_intermediates.append(intermediates)
-
-        # cat all output heads
-
-        all_outs = torch.cat(all_outs, dim = 1)
-
-        # cat all intermediates, if they exist
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
-
-        aggregated_intermediates = Intermediates(
-            qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
-            pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
-            post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
-        )
-
-        return all_outs, aggregated_intermediates
x_transformers/x_transformers.py
CHANGED
@@ -14,7 +14,7 @@ from typing import List, Callable, Optional
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

-from x_transformers.attend import Attend, Intermediates
+from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

 # constants
@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
@@ -662,7 +663,6 @@ class Attention(nn.Module):
         shared_kv = False,
         value_dim_head = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
-        cascading_heads = False,
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         onnxable = False
@@ -674,7 +674,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.max_attend_past = max_attend_past

-
         assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'

         value_dim_head = default(value_dim_head, dim_head)
@@ -705,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias,
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention
         self.qk_norm = qk_norm
@@ -738,10 +744,6 @@ class Attention(nn.Module):
             onnxable = onnxable
         )

-        if cascading_heads:
-            # cascading heads - wrap the Attend logic
-            self.attend = CascadingHeads(self.attend)
-
         # head scaling
         self.head_scale = head_scale
         if head_scale:
@@ -911,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params

+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads

         out = rearrange(out, 'b h n d -> b n (h d)')
{x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=KQ9mU_jE27whl6yQI67grF0S8Xhd3GndnM6Yd0-q-lw,61162
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
+x_transformers-1.23.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.2.dist-info/METADATA,sha256=8h0sbx8-4yNTOJuAZLbe5HQ16hsmZI1M_mT-rMIIMJc,661
+x_transformers-1.23.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.2.dist-info/RECORD,,
{x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/LICENSE
File without changes
{x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/WHEEL
File without changes
{x_transformers-1.23.0.dist-info → x_transformers-1.23.2.dist-info}/top_level.txt
File without changes