x-transformers 1.42.20__tar.gz → 1.42.21__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {x_transformers-1.42.20/x_transformers.egg-info → x_transformers-1.42.21}/PKG-INFO +1 -1
- {x_transformers-1.42.20 → x_transformers-1.42.21}/setup.py +1 -1
- {x_transformers-1.42.20 → x_transformers-1.42.21}/tests/test_x_transformers.py +22 -1
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/x_transformers.py +6 -3
- {x_transformers-1.42.20 → x_transformers-1.42.21/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.20 → x_transformers-1.42.21}/LICENSE +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/README.md +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/setup.cfg +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/top_level.txt +0 -0
@@ -409,9 +409,30 @@ def test_custom_alibi(flash: bool):
|
|
409
409
|
|
410
410
|
logits = model(x, pos = pos)
|
411
411
|
|
412
|
+
def test_custom_rotary_pos_emb():
|
413
|
+
from einops import repeat
|
414
|
+
|
415
|
+
model = TransformerWrapper(
|
416
|
+
num_tokens = 20_000,
|
417
|
+
max_seq_len = 1024,
|
418
|
+
attn_layers = Decoder(
|
419
|
+
dim = 512,
|
420
|
+
depth = 2,
|
421
|
+
heads = 8,
|
422
|
+
rotary_pos_emb = True
|
423
|
+
)
|
424
|
+
)
|
425
|
+
|
426
|
+
x = torch.randint(0, 20000, (4, 4))
|
427
|
+
|
428
|
+
pos = repeat(torch.arange(0, 4), "n -> b n", b=4)
|
429
|
+
|
430
|
+
logits1 = model(x, pos = pos)
|
431
|
+
logits2 = model(x)
|
432
|
+
assert torch.allclose(logits1, logits2)
|
433
|
+
|
412
434
|
@pytest.mark.parametrize('flash', (True, False))
|
413
435
|
def test_custom_alibi_across_heads(flash: bool):
|
414
|
-
|
415
436
|
model = Decoder(
|
416
437
|
dim = 512,
|
417
438
|
depth = 2,
|
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
|
|
655
655
|
def forward(self, t):
|
656
656
|
max_pos = t.max() + 1
|
657
657
|
|
658
|
-
|
658
|
+
if t.ndim == 1:
|
659
|
+
t = rearrange(t, 'n -> 1 n')
|
660
|
+
|
661
|
+
freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
|
659
662
|
freqs = torch.stack((freqs, freqs), dim = -1)
|
660
663
|
freqs = rearrange(freqs, '... d r -> ... (d r)')
|
661
664
|
|
@@ -679,8 +682,8 @@ def rotate_half(x):
|
|
679
682
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
680
683
|
rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
|
681
684
|
|
682
|
-
freqs = freqs[-seq_len:, :]
|
683
|
-
scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
|
685
|
+
freqs = freqs[:, -seq_len:, :]
|
686
|
+
scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
|
684
687
|
|
685
688
|
if t.ndim == 4 and freqs.ndim == 3:
|
686
689
|
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/nonautoregressive_wrapper.py
RENAMED
File without changes
|
{x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/xl_autoregressive_wrapper.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|