x-transformers 1.40.7__py3-none-any.whl → 1.40.8__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +23 -5
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/METADATA +1 -1
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/RECORD +6 -6
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1803,7 +1803,10 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for value residuals
+
         first_self_attn_inter = None
+        first_cross_attn_inter = None
 
         # go through the attention and feedforward layers
 
@@ -1856,20 +1859,35 @@ class AttentionLayers(Module):
 
             block = partial(block, **block_forward_kwargs)
 
-
-
-
+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual =
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
+            # store first self or cross attention intermediate for value residual
+
             if not exists(first_self_attn_inter) and layer_type == 'a':
                 first_self_attn_inter = inter
 
+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale
 
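In short, 1.40.8 stores the value tensors of the first self-attention and first cross-attention blocks and feeds them into the later blocks as value residuals when add_value_residual is enabled; previously only the self-attention ('a') layers received a value_residual, and the cross-attention ('c') layers did not. Below is a minimal usage sketch exercising the new cross-attention path. It assumes the public TransformerWrapper / Decoder constructors expose cross_attend and add_value_residual as in neighbouring 1.40.x releases; none of it comes from this diff.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True,        # adds 'c' (cross attention) blocks between the 'a' and 'f' blocks
        add_value_residual = True   # with 1.40.8, value residuals also reach the cross attention blocks
    )
)

x = torch.randint(0, 256, (1, 128))    # target token ids
context = torch.randn(1, 64, 512)      # sequence to cross-attend to (feature dim must match `dim` above)

logits = model(x, context = context)   # (1, 128, 256)

Note that the first 'a' and 'c' blocks still run with value_residual = None; only the subsequent blocks receive the stored first_*_attn_inter.values, as the intermediate-storing branch at the end of the hunk shows.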
{x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY,87367
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers-1.40.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.8.dist-info/METADATA,sha256=epku2YLENdljdFU7ge7hhUDXMujquHyb0jhoDSv9yFk,661
+x_transformers-1.40.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.8.dist-info/RECORD,,
{x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/LICENSE
File without changes
{x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/WHEEL
File without changes
{x_transformers-1.40.7.dist-info → x_transformers-1.40.8.dist-info}/top_level.txt
File without changes
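The RECORD entries above follow the standard wheel RECORD format (PEP 376 / PEP 427): path, "sha256=" followed by the urlsafe-base64-encoded SHA-256 digest with trailing '=' padding stripped, and the file size in bytes. A small sketch for recomputing such a hash locally; the path argument is only an example.

import base64
import hashlib

def record_hash(path):
    # RECORD-style hash: urlsafe base64 of the SHA-256 digest, '=' padding stripped
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')

# for the x_transformers.py shipped in the 1.40.8 wheel, this should match the entry above:
# sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY
print(record_hash('x_transformers/x_transformers.py'))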