x-transformers 1.21.2__py3-none-any.whl → 1.21.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,7 @@ from collections import namedtuple
 from dataclasses import dataclass
 from typing import List, Callable, Optional
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
 
 from x_transformers.attend import Attend, Intermediates, CascadingHeads
@@ -27,6 +27,7 @@ class LayerIntermediates:
     attn_intermediates: Optional[List[Intermediates]] = None
     layer_hiddens: Optional[List[Tensor]] = None
     attn_z_loss: Optional[Tensor] = None
+    mems: Optional[Tensor] = None
 
 # helpers
 
@@ -778,8 +779,8 @@ class Attention(nn.Module):
         r_input = x
 
         if exists(mem):
-            k_input = torch.cat((mem, k_input), dim = -2)
-            v_input = torch.cat((mem, v_input), dim = -2)
+            k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
+            v_input, _ = pack([mem, v_input], 'b * d')
 
         q = self.to_q(q_input)
         k = self.to_k(k_input)
@@ -792,11 +793,21 @@ class Attention(nn.Module):
 
         if exists(cache) and not has_context:
             ck, cv = cache.cached_kv
+
+            if exists(mem):
+                mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                mv, v = unpack(v, mem_packed_shape, 'b h * d')
+
             k = torch.cat((ck, k), dim = -2)
             v = torch.cat((cv, v), dim = -2)
 
+            if exists(mem):
+                k = torch.cat((mk, k), dim = -2)
+                v = torch.cat((mv, v), dim = -2)
+
         if return_intermediates:
-            cached_kv = (k, v)
+            mem_len = mem.shape[-2] if exists(mem) else 0
+            cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
         if self.qk_norm:
             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
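Note on the einops pack/unpack pair introduced above: pack concatenates a list of tensors along the single '*' axis and returns the per-tensor lengths needed to split them apart again, which is what lets the memory tokens be separated from the current keys/values before the KV cache is applied. A minimal stand-alone sketch (shapes are illustrative only, not taken from the library):

import torch
from einops import pack, unpack

mem     = torch.randn(2, 4, 64)    # (batch, num memory tokens, dim)
k_input = torch.randn(2, 16, 64)   # (batch, sequence length, dim)

# concatenate along the '*' axis, recording each part's length
packed, packed_shape = pack([mem, k_input], 'b * d')   # packed: (2, 20, 64)

# later, split back into the original parts; only the '*' lengths are
# recorded, so the unpack pattern may name extra axes (e.g. 'b h * d')
mem_part, k_part = unpack(packed, packed_shape, 'b * d')
assert mem_part.shape == mem.shape and k_part.shape == k_input.shape

Because only the '*' lengths are stored, the diff can pack (mem, k_input) before projection yet unpack the per-head keys with 'b h * d' afterwards.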
@@ -941,6 +952,8 @@ class AttentionLayers(nn.Module):
         custom_layers = None,
         sandwich_coef = None,
         par_ratio = None,
+        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
+        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
@@ -1046,6 +1059,15 @@ class AttentionLayers(nn.Module):
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
 
+        # setup weight tying, which is a special case of `layer_execute_order`
+
+        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
+
+        if weight_tie_layers:
+            assert not exists(layers_execute_order)
+            layers_execute_order = tuple(range(len(default_block))) * depth
+            depth = 1
+
         # calculate layer block order
 
         if exists(custom_layers):
@@ -1068,6 +1090,10 @@ class AttentionLayers(nn.Module):
             layer_types = default_block * depth
 
         self.layer_types = layer_types
+        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
+
+        assert all([i < len(self.layer_types) for i in self.layers_execute_order])
+
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
         # stochastic depth
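Taken together, the new constructor arguments allow an ALBERT-style weight-tied stack, or an arbitrary execution order over the instantiated layers. A hedged usage sketch, assuming the library's existing Decoder subclass with its usual dim/depth/heads keywords and a default ('a', 'f') block (attention then feedforward); the values are illustrative:

from x_transformers import Decoder

# one attention/feedforward block instantiated once, executed six times,
# so all six passes share the same weights (ALBERT-style)
tied = Decoder(dim = 512, depth = 6, heads = 8, weight_tie_layers = True)

# roughly equivalent, spelled out by hand: depth 1 plus an explicit
# execution order over the block's layer indices (0 = attention, 1 = feedforward)
explicit = Decoder(
    dim = 512,
    depth = 1,
    heads = 8,
    layers_execute_order = (0, 1) * 6
)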
@@ -1176,7 +1202,19 @@ class AttentionLayers(nn.Module):
 
         outer_residual = x * self.resi_dual_scale
 
-        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
+        # get layers to be executed
+
+        layer_variables = (
+            self.layer_types,
+            self.layers,
+            self.layer_dropouts
+        )
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+
+        # go through the attention and feedforward layers
+
+        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)
 
             if self.training and layer_dropout > 0. and random() < layer_dropout:
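The reindexing in the forward pass above repeats references to the same modules rather than copying them, which is what makes tied execution share parameters. A minimal stand-alone illustration with nn.Linear stand-ins (not the library's real attention/feedforward blocks):

import torch.nn as nn

layers = (nn.Linear(4, 4), nn.Linear(4, 4))   # stand-ins for one ('a', 'f') block
execute_order = (0, 1) * 3                    # run the same pair three times

executed = tuple(layers[i] for i in execute_order)

# the reordered tuple holds repeated references, not copies, so weights are shared
assert executed[0] is executed[2] is executed[4]
assert executed[1] is executed[3] is executed[5]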
@@ -1475,14 +1513,18 @@ class TransformerWrapper(nn.Module):
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True
 
-        if return_intermediates:
-            return out, intermediates
-
         if return_mems:
             hiddens = intermediates.hiddens
             new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
-            return out, new_mems
+
+            if not return_intermediates:
+                return out, new_mems
+
+            intermediates.mems = new_mems
+
+        if return_intermediates:
+            return out, intermediates
 
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
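Previously, passing both return_intermediates and return_mems returned the intermediates and dropped the new memories; with this change the memories ride along on the LayerIntermediates dataclass via the new mems field. A hedged usage sketch, assuming TransformerWrapper's existing num_tokens / max_seq_len / max_mem_len / attn_layers constructor arguments; the sizes are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    max_mem_len = 256,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# both flags set: the intermediates object now carries the detached memories
out, intermediates = model(x, return_mems = True, return_intermediates = True)
mems = intermediates.mems   # per-layer hidden-state tensors, truncated to max_mem_len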
@@ -67,6 +67,7 @@ class XLAutoregressiveWrapper(nn.Module):
         curr_pos = len(all_leading_tokens) * max_seq_len
         curr_mems = mems
 
+        cache = None
         out = start_tokens
 
         for _ in range(seq_len):
@@ -75,13 +76,17 @@ class XLAutoregressiveWrapper(nn.Module):
 
             x = out[:, curr_pos:]
 
-            logits, mems = self.net(
+            logits, cache = self.net(
                 x,
                 mems = curr_mems,
+                cache = cache,
                 return_mems = True,
+                return_intermediates = True,
                 **kwargs
             )
 
+            mems = cache.mems
+
             logits = logits[:, -1]
             filtered_logits = filter_logits_fn(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim=-1)
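With the cache threaded through the generation loop, the XL wrapper now reuses both the recurrence memories (mems) and the per-step key/value cache between decoding steps. A hedged end-to-end sketch, assuming the wrapper's existing generate(start_tokens, seq_len) interface seen in the surrounding code; sizes are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder, XLAutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    max_mem_len = 256,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = XLAutoregressiveWrapper(model)

prime = torch.randint(0, 256, (1, 4))

# generate past max_seq_len; segments beyond 512 tokens lean on the carried mems
generated = wrapper.generate(prime, 1024)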
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.21.2
+Version: 1.21.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=_LepMOwph_o3jio6tur8LEUBPM-2YIn7NpuBmqhU47E,1238
 x_transformers/autoregressive_wrapper.py,sha256=djA4nfE6_92SAzX1JI0KaC7krdLz1mvnZlaVOaerHDg,5372
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=Sw5FJlyDQ0fEINLw3G9aHJlLy3EBir0NphBVZ_gLiuA,56508
-x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.21.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.21.2.dist-info/METADATA,sha256=q47HwjcwIc31Liw_UWeHQKhjRZaWjrSLQo5ryHgCfkQ,661
-x_transformers-1.21.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.21.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.21.2.dist-info/RECORD,,
+x_transformers/x_transformers.py,sha256=vss6ISABCV74wgEapIt8nPK50j9-QU54hugRiaBh-sw,58088
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers-1.21.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.21.4.dist-info/METADATA,sha256=HzzezgisQhEH2H6D2tI-JDqNghpuUi6pDlHg0AI976U,661
+x_transformers-1.21.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.21.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.21.4.dist-info/RECORD,,