x-transformers 1.40.3__tar.gz → 1.40.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.40.3/x_transformers.egg-info → x_transformers-1.40.5}/PKG-INFO +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.5}/README.md +23 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/setup.py +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.5}/tests/test_x_transformers.py +18 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/x_transformers.py +20 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.5}/LICENSE +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/setup.cfg +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/__init__.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/attend.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/continuous.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/dpo.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/xval.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/top_level.txt +0 -0
@@ -2287,4 +2287,27 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
|
|
2287
2287
|
}
|
2288
2288
|
```
|
2289
2289
|
|
2290
|
+
```bibtex
|
2291
|
+
@article{Bai2019DeepEM,
|
2292
|
+
title = {Deep Equilibrium Models},
|
2293
|
+
author = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
|
2294
|
+
journal = {ArXiv},
|
2295
|
+
year = {2019},
|
2296
|
+
volume = {abs/1909.01377},
|
2297
|
+
url = {https://api.semanticscholar.org/CorpusID:202539738}
|
2298
|
+
}
|
2299
|
+
```
|
2300
|
+
|
2301
|
+
```bibtex
|
2302
|
+
@article{Wu2021MuseMorphoseFA,
|
2303
|
+
title = {MuseMorphose: Full-Song and Fine-Grained Piano Music Style Transfer With One Transformer VAE},
|
2304
|
+
author = {Shih-Lun Wu and Yi-Hsuan Yang},
|
2305
|
+
journal = {IEEE/ACM Transactions on Audio, Speech, and Language Processing},
|
2306
|
+
year = {2021},
|
2307
|
+
volume = {31},
|
2308
|
+
pages = {1953-1967},
|
2309
|
+
url = {https://api.semanticscholar.org/CorpusID:234338162}
|
2310
|
+
}
|
2311
|
+
```
|
2312
|
+
|
2290
2313
|
*solve intelligence... then use that to solve everything else.* - Demis Hassabis
|
@@ -298,3 +298,21 @@ def test_l2_distance(attn_one_kv_head):
|
|
298
298
|
x = torch.randint(0, 256, (1, 1024))
|
299
299
|
|
300
300
|
model(x)
|
301
|
+
|
302
|
+
def test_reinject_input():
|
303
|
+
|
304
|
+
model = TransformerWrapper(
|
305
|
+
num_tokens = 20000,
|
306
|
+
max_seq_len = 1024,
|
307
|
+
recycling = True,
|
308
|
+
attn_layers = Decoder(
|
309
|
+
dim = 512,
|
310
|
+
depth = 12,
|
311
|
+
heads = 8,
|
312
|
+
reinject_input = True
|
313
|
+
)
|
314
|
+
)
|
315
|
+
|
316
|
+
x = torch.randint(0, 256, (1, 12))
|
317
|
+
|
318
|
+
model(x) # (1, 12, 20000)
|
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
|
|
1353
1353
|
use_layerscale = False,
|
1354
1354
|
layerscale_init_value = 0.,
|
1355
1355
|
unet_skips = False,
|
1356
|
+
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
|
1356
1357
|
**kwargs
|
1357
1358
|
):
|
1358
1359
|
super().__init__()
|
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
|
|
1582
1583
|
|
1583
1584
|
self.skip_combines = ModuleList([])
|
1584
1585
|
|
1586
|
+
# whether there is reinjection of input at every layer
|
1587
|
+
|
1588
|
+
self.reinject_input = reinject_input
|
1589
|
+
self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
|
1590
|
+
|
1585
1591
|
# iterate and construct layers
|
1586
1592
|
|
1587
1593
|
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
@@ -1667,6 +1673,7 @@ class AttentionLayers(Module):
|
|
1667
1673
|
rotary_pos_emb = None,
|
1668
1674
|
attn_bias = None,
|
1669
1675
|
condition = None,
|
1676
|
+
in_attn_cond = None, # https://arxiv.org/abs/2105.04090
|
1670
1677
|
layers_execute_order: tuple[int, ...] | None = None
|
1671
1678
|
):
|
1672
1679
|
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
|
@@ -1766,6 +1773,16 @@ class AttentionLayers(Module):
|
|
1766
1773
|
|
1767
1774
|
layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
|
1768
1775
|
|
1776
|
+
# derived input for reinjection if needed
|
1777
|
+
|
1778
|
+
if self.reinject_input:
|
1779
|
+
assert not exists(in_attn_cond)
|
1780
|
+
inp_inject = self.reinject_input_proj(x)
|
1781
|
+
|
1782
|
+
elif exists(in_attn_cond):
|
1783
|
+
# handle in-attention conditioning, which serves the same purpose of having the network learn the residual
|
1784
|
+
inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
|
1785
|
+
|
1769
1786
|
# store all hiddens for skips
|
1770
1787
|
|
1771
1788
|
skip_hiddens = []
|
@@ -1810,6 +1827,9 @@ class AttentionLayers(Module):
|
|
1810
1827
|
post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
|
1811
1828
|
post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
|
1812
1829
|
|
1830
|
+
if self.reinject_input:
|
1831
|
+
x = x + inp_inject
|
1832
|
+
|
1813
1833
|
if exists(pre_norm):
|
1814
1834
|
x = pre_norm(x)
|
1815
1835
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|