x-transformers 1.34.0__py3-none-any.whl → 1.35.0__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- x_transformers/attend.py +21 -3
- x_transformers/x_transformers.py +62 -1
- {x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/METADATA +1 -1
- {x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/RECORD +7 -7
- {x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/WHEEL +1 -1
- {x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -138,9 +138,27 @@ class Attend(Module):
         # flash attention

         self.flash = flash
-        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

-
+        torch_version = version.parse(torch.__version__)
+        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # torch 2.3 uses new backend and context manager
+
+        if torch_version >= version.parse('2.3'):
+            from torch.nn.attention import SDPBackend
+
+            str_to_backend = dict(
+                enable_flash = SDPBackend.FLASH_ATTENTION,
+                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                enable_math = SDPBackend.MATH,
+                enable_cudnn = SDPBackend.CUDNN_ATTENTION
+            )
+
+            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+        else:
+            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

     def flash_attn(
         self,
@@ -231,7 +249,7 @@ class Attend(Module):

         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

-        with
+        with self.sdp_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
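Taken together, the two attend.py hunks choose the scaled-dot-product-attention context manager once at construction: the newer torch.nn.attention.sdpa_kernel (which takes a list of SDPBackend enums) on torch 2.3+, falling back to the keyword-based torch.backends.cuda.sdp_kernel otherwise, and flash_attn then simply enters self.sdp_context_manager(). Below is a minimal standalone sketch of that selection logic; the sdp_kwargs dict of enable flags and the tensor shapes are illustrative assumptions, not the library's actual defaults.

# Hedged sketch of the version-gated SDP backend selection introduced in 1.35.0;
# `sdp_kwargs` here is an assumed dict of enable flags mirroring the names in the diff.
from functools import partial

import torch
import torch.nn.functional as F
from packaging import version

sdp_kwargs = dict(enable_flash = True, enable_math = True, enable_mem_efficient = True)

torch_version = version.parse(torch.__version__)

if torch_version >= version.parse('2.3'):
    # torch >= 2.3: a single context manager that accepts a list of SDPBackend enums
    from torch.nn.attention import SDPBackend, sdpa_kernel

    str_to_backend = dict(
        enable_flash = SDPBackend.FLASH_ATTENTION,
        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
        enable_math = SDPBackend.MATH,
        enable_cudnn = SDPBackend.CUDNN_ATTENTION
    )

    backends = [str_to_backend[k] for k, enabled in sdp_kwargs.items() if enabled]
    sdp_context_manager = partial(sdpa_kernel, backends)
else:
    # older torch: the keyword-based (now deprecated) context manager
    sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

q = k = v = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, dim_head)

with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)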
x_transformers/x_transformers.py
CHANGED
@@ -810,6 +810,19 @@ class AdaptiveLayerScale(Module):
         out, *rest = out
         return out * gamma, *rest

+# skip connection combining
+
+class ConcatCombine(Module):
+    def __init__(self, dim, prev_layer_ind):
+        super().__init__()
+        self.prev_layer_ind = prev_layer_ind
+        self.combine = nn.Linear(dim * 2, dim, bias = False)
+
+    def forward(self, x, prev_layers: list[Tensor]):
+        skip = prev_layers[self.prev_layer_ind]
+        concatted_skip = torch.cat((skip, x), dim = -1)
+        return self.combine(concatted_skip)
+
 # feedforward

 class GLU(Module):
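The new ConcatCombine module concatenates a previously recorded hidden state onto the current one along the feature dimension and projects back down to dim with a bias-free linear layer. A small illustrative check of that behaviour on its own; the tensor sizes below are arbitrary.

# Illustrative standalone use of ConcatCombine; shapes are arbitrary.
import torch
from x_transformers.x_transformers import ConcatCombine

combine = ConcatCombine(dim = 512, prev_layer_ind = 0)

x = torch.randn(2, 1024, 512)               # current hidden states
prev_layers = [torch.randn(2, 1024, 512)]   # hiddens saved from earlier layers

out = combine(x, prev_layers)               # cat to feature dim 1024, project back to 512
assert out.shape == (2, 1024, 512)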
@@ -1307,6 +1320,7 @@ class AttentionLayers(Module):
         disable_abs_pos_emb = None,
         use_layerscale = False,
         layerscale_init_value = 0.,
+        unet_skips = False,
         **kwargs
     ):
         super().__init__()

@@ -1468,6 +1482,8 @@ class AttentionLayers(Module):

         # calculate layer block order

+        len_default_block = 1
+
         if exists(custom_layers):
             layer_types = custom_layers
         elif exists(par_ratio):

@@ -1487,6 +1503,7 @@ class AttentionLayers(Module):
         else:
             assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth
+            len_default_block = len(default_block)

         self.layer_types = layer_types
         self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))

@@ -1522,11 +1539,31 @@ class AttentionLayers(Module):

         self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()

+        # whether unet or not
+
+        self.unet_skips = unet_skips
+        num_skips = self.depth // len_default_block
+
+        assert not (unet_skips and num_skips == 0), 'must have depth of at least 2 for unet skip connections'
+
+        skip_indices = [i * len_default_block for i in range(num_skips)]
+
+        self.skip_combines = ModuleList([])
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
+
+            # `ind` is the index of each module - attention, feedforward, cross attention
+            # but `block_ind` refers to the typical enumeration of a transformer block (attn + ff + [optional] cross attn)
+
+            block_begin = divisible_by(ind, len_default_block)
+            block_ind = ind // len_default_block
+
             is_last_layer = ind == (len(self.layer_types) - 1)

+            # attention, cross attention, feedforward
+
             if layer_type == 'a':
                 layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':

@@ -1548,6 +1585,14 @@ class AttentionLayers(Module):
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

+            # handle unet skip connection
+
+            skip_combine = None
+            is_latter_half = block_begin and block_ind >= (self.depth / 2)
+
+            if self.unet_skips and is_latter_half:
+                skip_combine = ConcatCombine(dim, skip_indices.pop())
+
             # all normalizations of the layer

             pre_branch_norm = norm_fn() if pre_norm else None

@@ -1560,6 +1605,8 @@ class AttentionLayers(Module):
                 post_main_norm
             ])

+            self.skip_combines.append(skip_combine)
+
             self.layers.append(ModuleList([
                 norms,
                 layer,
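For intuition about the constructor logic above, here is a tiny pure-Python trace under stated assumptions: a hypothetical decoder of 4 blocks with the default ('a', 'f') block, so len_default_block = 2, and self.depth taken to be the block count. Latter-half blocks pop mirrored entries off skip_indices, which is what produces the U-Net style pairing.

# Toy trace (not library code) of how skip sources get assigned in the constructor.
depth = 4
len_default_block = 2                              # ('a', 'f') per block in a Decoder

num_skips = depth // len_default_block             # 2
skip_indices = [i * len_default_block for i in range(num_skips)]   # [0, 2]

pairs = {}
for ind in range(depth * len_default_block):       # module indices 0..7
    block_begin = (ind % len_default_block) == 0   # same test as divisible_by(ind, len_default_block)
    block_ind = ind // len_default_block

    if block_begin and block_ind >= depth / 2:     # latter half of the blocks
        pairs[block_ind] = skip_indices.pop()

print(pairs)  # {2: 2, 3: 0} -> block 2 reuses the hidden entering block 1, block 3 the one entering block 0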
@@ -1670,6 +1717,7 @@ class AttentionLayers(Module):

         layer_variables = (
             self.layer_types,
+            self.skip_combines,
             self.layers,
             self.layer_dropouts
         )

@@ -1680,11 +1728,24 @@ class AttentionLayers(Module):

         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        # store all hiddens for skips
+
+        skip_hiddens = []
+
         # go through the attention and feedforward layers

-        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)

+            # handle skip connections
+
+            skip_hiddens.append(x)
+
+            if exists(skip_combine):
+                x = skip_combine(x, skip_hiddens)
+
+            # layer dropout
+
             if self.training and layer_dropout > 0. and random() < layer_dropout:
                 continue
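Together these hunks give AttentionLayers optional U-Net style skips: the forward pass pushes every layer's input onto skip_hiddens, and any layer that was assigned a ConcatCombine fuses the matching early hidden back in before running its block. A minimal usage sketch follows; the hyperparameters are illustrative and only unet_skips is the flag new to 1.35.0.

# Sketch of enabling the new unet_skips flag on a decoder; sizes are illustrative.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,            # needs a depth of at least 2 for the skip connections
        heads = 8,
        unet_skips = True     # latter-half blocks concat-combine hiddens from the first half
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)             # (1, 1024, 20000)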
{x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
 x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=2oQoQs7RMbFrVdMeOddy6yq1MhJxnficjORmMWBjjPo,80593
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.35.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.35.0.dist-info/METADATA,sha256=D32aQ96BsP6BXjikkuZUHc77sO6thZVO9cI_xFgLQF0,661
+x_transformers-1.35.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+x_transformers-1.35.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.35.0.dist-info/RECORD,,

{x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/LICENSE
File without changes

{x_transformers-1.34.0.dist-info → x_transformers-1.35.0.dist-info}/top_level.txt
File without changes