x-transformers 1.42.19.tar.gz → 1.42.21.tar.gz
- {x_transformers-1.42.19/x_transformers.egg-info → x_transformers-1.42.21}/PKG-INFO +1 -1
- {x_transformers-1.42.19 → x_transformers-1.42.21}/setup.py +1 -1
- {x_transformers-1.42.19 → x_transformers-1.42.21}/tests/test_x_transformers.py +28 -3
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/attend.py +1 -1
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/x_transformers.py +6 -3
- {x_transformers-1.42.19 → x_transformers-1.42.21/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.19 → x_transformers-1.42.21}/LICENSE +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/README.md +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/setup.cfg +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.19 → x_transformers-1.42.21}/tests/test_x_transformers.py

@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,8 +409,30 @@ def test_custom_alibi():
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+def test_custom_rotary_pos_emb():
+    from einops import repeat
 
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            rotary_pos_emb = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (4, 4))
+
+    pos = repeat(torch.arange(0, 4), "n -> b n", b=4)
+
+    logits1 = model(x, pos = pos)
+    logits2 = model(x)
+    assert torch.allclose(logits1, logits2)
+
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,
         depth = 2,
@@ -417,6 +441,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
            slopes = [1, 1]
        ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)
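For orientation, the new parametrized tests above exercise the public TransformerWrapper / Decoder API with flash attention enabled alongside explicitly supplied positions. A minimal usage sketch in that spirit follows; the per-batch shape of pos here is illustrative, since the alibi test's own pos construction falls outside the hunk shown.

import torch
from x_transformers import TransformerWrapper, Decoder

# custom alibi positional bias together with flash attention,
# mirroring the test_custom_alibi(flash = True) case above
model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        alibi_pos_bias = True,
        attn_flash = True
    )
)

x = torch.randint(0, 20_000, (2, 8))
pos = torch.arange(8).expand(2, -1)   # explicit per-batch positions (illustrative shape)

logits = model(x, pos = pos)
assert logits.shape == (2, 8, 20_000)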
{x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/attend.py

@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias =
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
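The one-line change above broadcasts the positional bias to the full (batch, heads, q_len, k_len) shape before it enters the flash attention path. As a standalone illustration of the Tensor.expand semantics relied on here (the (heads, i, j) input shape is an assumption for the sketch; the real shape reaching Attend depends on the relative position module in use):

import torch

batch, heads, i, j = 2, 8, 4, 4

# hypothetical per-head bias, e.g. an alibi-style bias of shape (heads, i, j)
attn_bias = torch.randn(heads, i, j)

# expand prepends the missing batch dimension and broadcasts it without copying;
# -1 keeps the existing sizes of the trailing (i, j) dimensions
expanded = attn_bias.expand(batch, heads, -1, -1)

assert expanded.shape == (batch, heads, i, j)
assert expanded.data_ptr() == attn_bias.data_ptr()   # a view, no new memory allocated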
{x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/x_transformers.py

@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1
 
-
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
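Both hunks above follow from the same change: rotary frequencies now carry an explicit batch dimension, so 1D position tensors are lifted to shape (1, n) and the sequence axis is sliced at dim 1. A small sketch of the batched einsum; dim and theta are illustrative values, and inv_freq follows the standard rotary formulation:

import torch
from einops import rearrange

dim, theta = 64, 10_000                       # illustrative rotary dimension and base

# standard rotary inverse frequencies
inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))

# batched positions, e.g. the custom pos used in test_custom_rotary_pos_emb
pos = torch.arange(4).expand(3, -1)           # shape (b, n) = (3, 4)

if pos.ndim == 1:                             # 1d positions become a batch of one
    pos = rearrange(pos, 'n -> 1 n')

freqs = torch.einsum('b i , j -> b i j', pos.type_as(inv_freq), inv_freq)
assert freqs.shape == (3, 4, dim // 2)        # one frequency vector per (batch, position)

# with the batch dimension in front, the last seq_len positions are sliced along dim 1,
# matching the updated freqs[:, -seq_len:, :] in apply_rotary_pos_emb
seq_len = 4
assert freqs[:, -seq_len:, :].shape == (3, 4, dim // 2)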