x-transformers 1.42.15.tar.gz → 1.42.17.tar.gz
- {x_transformers-1.42.15/x_transformers.egg-info → x_transformers-1.42.17}/PKG-INFO +1 -1
- {x_transformers-1.42.15 → x_transformers-1.42.17}/setup.py +1 -1
- {x_transformers-1.42.15 → x_transformers-1.42.17}/tests/test_x_transformers.py +6 -2
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/x_transformers.py +5 -3
- {x_transformers-1.42.15 → x_transformers-1.42.17/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.15 → x_transformers-1.42.17}/LICENSE +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/README.md +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/setup.cfg +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.15 → x_transformers-1.42.17}/tests/test_x_transformers.py

```diff
@@ -352,7 +352,10 @@ def test_value_residual(
 
     model(x)
 
-def test_forgetting_transformer():
+@pytest.mark.parametrize('has_num_mem_kv', (False, True))
+def test_forgetting_transformer(
+    has_num_mem_kv: bool
+):
 
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -361,7 +364,8 @@ def test_forgetting_transformer():
             dim = 128,
             depth = 6,
             heads = 8,
-            attn_data_dependent_alibi = True
+            attn_num_mem_kv = 1 if has_num_mem_kv else 0,
+            attn_data_dependent_alibi = True
         )
     )
 
```
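The updated test parametrizes over `has_num_mem_kv`, so the forgetting-transformer path is now exercised both with and without memory key/values. As a quick orientation, here is a minimal sketch of the same configuration used directly, assuming the usual `TransformerWrapper`/`Decoder` imports; `max_seq_len = 1024` is an assumed value, since the diff does not show that line:

```python
# minimal sketch of the configuration the updated test exercises
# (max_seq_len is assumed; the diff does not show it)
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_num_mem_kv = 1,                # the new parametrization: memory key / values on
        attn_data_dependent_alibi = True    # forgetting transformer's data dependent bias
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```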
{x_transformers-1.42.15 → x_transformers-1.42.17}/x_transformers/x_transformers.py

```diff
@@ -1235,9 +1235,9 @@ class Attention(Module):
         # maybe learned value residual mixer per token
 
         self.to_value_residual_mix = nn.Sequential(
-            nn.Linear(dim, 1),
+            nn.Linear(dim, heads),
             nn.Sigmoid(),
-            Rearrange('b n 1 -> b 1 n 1')
+            Rearrange('b n h -> b h n 1')
         ) if learned_value_residual_mix else always(0.5)
 
         # attention on attention
@@ -1428,13 +1428,15 @@ class Attention(Module):
         else:
             attn_bias = rel_pos(i, j)
 
-        attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
+        attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0)) # handle memory key / values
 
         # prepare data dependent alibi from forgetting transformers paper, if needed
 
         if exists(self.data_dependent_alibi):
             attn_bias = self.data_dependent_alibi(x)
 
+            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
+
         # attention is all we need
 
         out, intermediates = self.attend(
```
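The first `x_transformers.py` hunk widens the learned value-residual mix from one shared gate per token to one gate per token per head: the `Linear` now emits `heads` logits, and the `Rearrange` lays them out as `(b, h, n, 1)` so they broadcast against per-head values. A self-contained shape check using only torch and einops; the convex combination at the end illustrates the usual value-residual mixing and is not a copy of the library's internals:

```python
# standalone shape check for the per-head value residual mix
# (illustrative; not tied to x_transformers internals)
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, dim, heads, dim_head = 2, 16, 128, 8, 64

to_mix = nn.Sequential(
    nn.Linear(dim, heads),          # one mix logit per head (was a single shared logit)
    nn.Sigmoid(),                   # gate in [0, 1]
    Rearrange('b n h -> b h n 1')   # broadcastable against (b, h, n, dim_head) values
)

x = torch.randn(b, n, dim)
mix = to_mix(x)                     # (2, 8, 16, 1)

values = torch.randn(b, heads, n, dim_head)
first_values = torch.randn(b, heads, n, dim_head)

# convex combination of this layer's values with the first layer's values
mixed = values * mix + first_values * (1. - mix)
print(mix.shape, mixed.shape)       # torch.Size([2, 8, 16, 1]) torch.Size([2, 8, 16, 64])
```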
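The second hunk applies the same memory-key/value padding to the data-dependent ALiBi bias that was already applied to the `rel_pos` bias: the bias gains `num_mem_kv` zero columns on the key side, so memory key/values carry no positional penalty. Below is a sketch of what `pad_at_dim` does, with a stand-in implementation in terms of `F.pad`; the library's actual helper may differ in signature or details:

```python
# sketch of padding the attention bias for memory key / values;
# pad_at_dim here mimics the x_transformers helper, details may differ
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # number of dims to the right of `dim`, each needing a (0, 0) no-op pad
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

num_mem_kv, heads, i, j = 4, 8, 16, 16

attn_bias = torch.randn(heads, i, j)                # bias over real key positions
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))  # default dim = -1: the key axis

print(attn_bias.shape)  # torch.Size([8, 16, 20]) -- memory keys get zero bias
```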