x-transformers 1.34.0__py3-none-any.whl → 1.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -138,9 +138,27 @@ class Attend(Module):
          # flash attention
  
          self.flash = flash
-         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
  
-         self.sdp_kwargs = sdp_kwargs
+         torch_version = version.parse(torch.__version__)
+         assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+         # torch 2.3 uses new backend and context manager
+
+         if torch_version >= version.parse('2.3'):
+             from torch.nn.attention import SDPBackend
+
+             str_to_backend = dict(
+                 enable_flash = SDPBackend.FLASH_ATTENTION,
+                 enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                 enable_math = SDPBackend.MATH,
+                 enable_cudnn = SDPBackend.CUDNN_ATTENTION
+             )
+
+             sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+             self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+         else:
+             self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
  
      def flash_attn(
          self,
@@ -231,7 +249,7 @@ class Attend(Module):
  
          # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
  
-         with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
+         with self.sdp_context_manager():
              out = F.scaled_dot_product_attention(
                  q, k, v,
                  attn_mask = mask,
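
For context, here is a minimal, self-contained sketch (not taken from the package) of the behaviour this change encodes: on torch 2.3+ the backend selection goes through `torch.nn.attention.sdpa_kernel` with a list of `SDPBackend` enums, while older releases keep using the `torch.backends.cuda.sdp_kernel` context manager with boolean flags. The `sdp_kwargs` dict below is an assumed stand-in for the module's settings.

```python
from functools import partial

import torch
import torch.nn.functional as F
from packaging import version

# assumed defaults, mirroring the boolean flags the old context manager expects
sdp_kwargs = dict(enable_flash = True, enable_math = True, enable_mem_efficient = True)

if version.parse(torch.__version__) >= version.parse('2.3'):
    # new API: choose backends explicitly via enums
    from torch.nn.attention import SDPBackend, sdpa_kernel

    str_to_backend = dict(
        enable_flash = SDPBackend.FLASH_ATTENTION,
        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
        enable_math = SDPBackend.MATH
    )

    backends = [str_to_backend[name] for name, enabled in sdp_kwargs.items() if enabled]
    sdp_context_manager = partial(sdpa_kernel, backends)
else:
    # old API: boolean keyword flags
    sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

q = k = v = torch.randn(2, 8, 128, 64)   # (batch, heads, seq, dim_head)

with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v)
```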
x_transformers/x_transformers.py CHANGED
@@ -810,6 +810,19 @@ class AdaptiveLayerScale(Module):
          out, *rest = out
          return out * gamma, *rest
  
+ # skip connection combining
+
+ class ConcatCombine(Module):
+     def __init__(self, dim, prev_layer_ind):
+         super().__init__()
+         self.prev_layer_ind = prev_layer_ind
+         self.combine = nn.Linear(dim * 2, dim, bias = False)
+
+     def forward(self, x, prev_layers: list[Tensor]):
+         skip = prev_layers[self.prev_layer_ind]
+         concatted_skip = torch.cat((skip, x), dim = -1)
+         return self.combine(concatted_skip)
+
  # feedforward
  
  class GLU(Module):
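
A quick usage sketch of the new class, assuming it is importable from `x_transformers.x_transformers` as a top-level symbol (as the hunk suggests): it concatenates an earlier hidden state with the current one along the feature dimension and projects back down to `dim`.

```python
import torch
from x_transformers.x_transformers import ConcatCombine

dim = 512
combine = ConcatCombine(dim, prev_layer_ind = 0)

prev_hiddens = [torch.randn(2, 1024, dim)]   # hiddens recorded from earlier layers
x = torch.randn(2, 1024, dim)                # current hidden

out = combine(x, prev_hiddens)               # cat on last dim, then Linear(dim * 2, dim)
assert out.shape == (2, 1024, dim)
```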
@@ -1307,6 +1320,7 @@ class AttentionLayers(Module):
          disable_abs_pos_emb = None,
          use_layerscale = False,
          layerscale_init_value = 0.,
+         unet_skips = False,
          **kwargs
      ):
          super().__init__()
@@ -1468,6 +1482,8 @@ class AttentionLayers(Module):
  
          # calculate layer block order
  
+         len_default_block = 1
+
          if exists(custom_layers):
              layer_types = custom_layers
          elif exists(par_ratio):
@@ -1487,6 +1503,7 @@ class AttentionLayers(Module):
          else:
              assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
              layer_types = default_block * depth
+             len_default_block = len(default_block)
  
          self.layer_types = layer_types
          self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
@@ -1522,11 +1539,31 @@ class AttentionLayers(Module):
  
          self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
  
+         # whether unet or not
+
+         self.unet_skips = unet_skips
+         num_skips = self.depth // len_default_block
+
+         assert not (unet_skips and num_skips == 0), 'must have depth of at least 2 for unet skip connections'
+
+         skip_indices = [i * len_default_block for i in range(num_skips)]
+
+         self.skip_combines = ModuleList([])
+
          # iterate and construct layers
  
          for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
+
+             # `ind` is the index of each module - attention, feedforward, cross attention
+             # but `block_ind` refers to the typical enumeration of a transformer block (attn + ff + [optional] cross attn)
+
+             block_begin = divisible_by(ind, len_default_block)
+             block_ind = ind // len_default_block
+
              is_last_layer = ind == (len(self.layer_types) - 1)
  
+             # attention, cross attention, feedforward
+
              if layer_type == 'a':
                  layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
              elif layer_type == 'c':
@@ -1548,6 +1585,14 @@ class AttentionLayers(Module):
              residual_fn = GRUGating if gate_residual else Residual
              residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
  
+             # handle unet skip connection
+
+             skip_combine = None
+             is_latter_half = block_begin and block_ind >= (self.depth / 2)
+
+             if self.unet_skips and is_latter_half:
+                 skip_combine = ConcatCombine(dim, skip_indices.pop())
+
              # all normalizations of the layer
  
              pre_branch_norm = norm_fn() if pre_norm else None
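
To make the bookkeeping in the two hunks above concrete, here is a small worked example under assumed settings (`depth = 6`, a default block of `('a', 'f')`, so `len_default_block = 2`); it is plain Python, independent of the library, and only mirrors the arithmetic shown in the diff.

```python
# assumed: depth = 6 blocks, each block = attention + feedforward ('a', 'f')
depth = 6
len_default_block = 2

num_skips = depth // len_default_block                                # 3
skip_indices = [i * len_default_block for i in range(num_skips)]      # [0, 2, 4]

for ind in range(depth * len_default_block):                          # 12 modules in total
    block_begin = (ind % len_default_block) == 0
    block_ind = ind // len_default_block
    is_latter_half = block_begin and block_ind >= depth / 2

    if is_latter_half:
        # latter-half blocks pop in reverse, mirroring the first half (u-net style)
        print(f'block {block_ind} (module {ind}) combines with the hidden stored at module {skip_indices.pop()}')

# block 3 (module 6) combines with the hidden stored at module 4
# block 4 (module 8) combines with the hidden stored at module 2
# block 5 (module 10) combines with the hidden stored at module 0
```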
@@ -1560,6 +1605,8 @@ class AttentionLayers(Module):
                  post_main_norm
              ])
  
+             self.skip_combines.append(skip_combine)
+
              self.layers.append(ModuleList([
                  norms,
                  layer,
@@ -1670,6 +1717,7 @@ class AttentionLayers(Module):
  
          layer_variables = (
              self.layer_types,
+             self.skip_combines,
              self.layers,
              self.layer_dropouts
          )
@@ -1680,11 +1728,24 @@ class AttentionLayers(Module):
  
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
  
+         # store all hiddens for skips
+
+         skip_hiddens = []
+
          # go through the attention and feedforward layers
  
-         for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
              is_last = ind == (len(self.layers) - 1)
  
+             # handle skip connections
+
+             skip_hiddens.append(x)
+
+             if exists(skip_combine):
+                 x = skip_combine(x, skip_hiddens)
+
+             # layer dropout
+
              if self.training and layer_dropout > 0. and random() < layer_dropout:
                  continue
  
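Putting the x_transformers.py changes together, the feature should be reachable from the user-facing wrappers roughly as below. This is a hedged sketch (the flag simply flows through `**kwargs` into `AttentionLayers`), not an excerpt from the package's docs.

```python
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,              # depth >= 2 is required for the skip connections to exist
        heads = 8,
        unet_skips = True       # new in 1.35.0
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)          # (1, 1024, 256)
```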
x_transformers-1.35.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.34.0
+ Version: 1.35.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.35.0.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
  x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
- x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
+ x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
  x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
  x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
- x_transformers/x_transformers.py,sha256=hs9j-lHukVGYLlpbBhn4CZhSzI7s0x6bYtEhCc33ftE,78680
+ x_transformers/x_transformers.py,sha256=2oQoQs7RMbFrVdMeOddy6yq1MhJxnficjORmMWBjjPo,80593
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
- x_transformers-1.34.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.34.0.dist-info/METADATA,sha256=aTRBJepYjojT5TFi8W2oK4j7daQGRQaWwj2HHnnwDCQ,661
- x_transformers-1.34.0.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
- x_transformers-1.34.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.34.0.dist-info/RECORD,,
+ x_transformers-1.35.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.35.0.dist-info/METADATA,sha256=D32aQ96BsP6BXjikkuZUHc77sO6thZVO9cI_xFgLQF0,661
+ x_transformers-1.35.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+ x_transformers-1.35.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.35.0.dist-info/RECORD,,
x_transformers-1.35.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (74.0.0)
+ Generator: setuptools (74.1.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
  