x-transformers 1.23.0__py3-none-any.whl → 1.23.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -346,79 +346,3 @@ class Attend(nn.Module):
         )
 
         return out, intermediates
-
-# cascading heads logic
-
-def to_single_heads(t, dim = 1):
-    heads = t.unbind(dim = dim)
-    return tuple(head.unsqueeze(dim) for head in heads)
-
-class CascadingHeads(nn.Module):
-    def __init__(self, attend: Attend):
-        super().__init__()
-        self.attend = attend
-
-    def forward(
-        self,
-        q, k, v,
-        mask = None,
-        attn_bias = None,
-        prev_attn = None
-    ):
-        assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same'
-
-        # split inputs into per-head inputs
-
-        heads = q.shape[1]
-
-        queries = to_single_heads(q)
-        keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
-        values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)
-
-        mask = (mask,) * heads
-
-        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
-        prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
-
-        # now loop through each head, without output of previous head summed with the next head
-        # thus cascading
-
-        all_outs = []
-        all_intermediates = []
-
-        prev_head_out = None
-
-        for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn):
-
-            if exists(prev_head_out):
-                h_q = h_q + prev_head_out
-
-            out, intermediates = self.attend(
-                h_q, h_k, h_v,
-                mask = h_mask,
-                attn_bias = h_attn_bias,
-                prev_attn = h_prev_attn
-            )
-
-            prev_head_out = out
-
-            all_outs.append(out)
-            all_intermediates.append(intermediates)
-
-        # cat all output heads
-
-        all_outs = torch.cat(all_outs, dim = 1)
-
-        # cat all intermediates, if they exist
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
-
-        aggregated_intermediates = Intermediates(
-            qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
-            pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
-            post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
-        )
-
-        return all_outs, aggregated_intermediates
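The whole cascading-heads code path (the module-level to_single_heads helper and the CascadingHeads wrapper) is gone in 1.23.2, together with the cascading_heads flag on Attention removed further down. A minimal sketch of what that means for existing configs, assuming nothing upstream swallows unknown keyword arguments before Attention.__init__:

from x_transformers.x_transformers import Attention

# accepted in 1.23.0, raises TypeError ("unexpected keyword argument") in 1.23.2
attn = Attention(dim = 512, heads = 8, cascading_heads = True)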
x_transformers/x_transformers.py CHANGED
@@ -14,7 +14,7 @@ from typing import List, Callable, Optional
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
 
-from x_transformers.attend import Attend, Intermediates, CascadingHeads
+from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
 # constants
@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
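The new gate_value_heads flag reaches Attention through the usual attn_-prefixed keyword routing, so it can be switched on from the high-level wrappers. A minimal sketch under that assumption (toy sizes; the only 1.23.2-specific bit is attn_gate_value_heads):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_gate_value_heads = True  # new per-head output gating
    )
)

logits = model(torch.randint(0, 256, (1, 128)))  # (1, 128, 256)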
@@ -662,7 +663,6 @@ class Attention(nn.Module):
         shared_kv = False,
         value_dim_head = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
-        cascading_heads = False,
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         onnxable = False
@@ -674,7 +674,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.max_attend_past = max_attend_past
 
-
         assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
 
         value_dim_head = default(value_dim_head, dim_head)
@@ -705,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias, 1)
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)
 
         # cosine sim attention
         self.qk_norm = qk_norm
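Note the initialisation on both gates: zero weight and a bias of 10 puts the gate at sigmoid(10) ≈ 0.99995, i.e. effectively wide open, so gating is a near no-op at the start of training and only learns to attenuate values (or whole heads) where useful. The existing gate_values gate is bumped from bias 1 to bias 10, which gives it the same nearly-open start.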
@@ -738,10 +744,6 @@ class Attention(nn.Module):
             onnxable = onnxable
         )
 
-        if cascading_heads:
-            # cascading heads - wrap the Attend logic
-            self.attend = CascadingHeads(self.attend)
-
         # head scaling
         self.head_scale = head_scale
         if head_scale:
@@ -911,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params
 
+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads
 
         out = rearrange(out, 'b h n d -> b n (h d)')
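A minimal standalone sketch of the per-head gating added above, with shapes assumed to match the rest of Attention.forward (x is the layer input of shape (b, n, dim), out the per-head attention output of shape (b, heads, n, dim_head)):

import torch
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512           # toy sizes for illustration
x   = torch.randn(b, n, dim)                  # tokens entering the attention layer
out = torch.randn(b, h, n, d)                 # per-head attention output, pre-merge

to_v_head_gate = nn.Linear(dim, h)
nn.init.constant_(to_v_head_gate.weight, 0)
nn.init.constant_(to_v_head_gate.bias, 10)    # sigmoid(10) ≈ 1: gates start open

head_gate = to_v_head_gate(x)                                    # (b, n, h)
out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()   # scale each head per position

Each head thus receives a data-dependent scalar per position, letting the model damp a head's contribution for tokens where attending would only add noise.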
x_transformers-1.23.0.dist-info/METADATA → x_transformers-1.23.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.23.0
+Version: 1.23.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.23.0.dist-info/RECORD → x_transformers-1.23.2.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=T2EzF_o0qVxIC0WvWoDDO2sY6J3h-aXAK0vN4McDgbc,13819
+x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=o8PJ0aZatavxyqx80JLh6Lk-8_C8H-HRwlc1dHsIV6g,60760
+x_transformers/x_transformers.py,sha256=KQ9mU_jE27whl6yQI67grF0S8Xhd3GndnM6Yd0-q-lw,61162
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.0.dist-info/METADATA,sha256=cA6JGJ3U7NSpBXjaNvXRkpLTF4YXqLpTQA-8C-RHWk8,661
-x_transformers-1.23.0.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.0.dist-info/RECORD,,
+x_transformers-1.23.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.2.dist-info/METADATA,sha256=8h0sbx8-4yNTOJuAZLbe5HQ16hsmZI1M_mT-rMIIMJc,661
+x_transformers-1.23.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.2.dist-info/RECORD,,