x-transformers 1.27.11__py3-none-any.whl → 1.27.14__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.
- x_transformers/x_transformers.py +4 -9
- {x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/METADATA +1 -1
- {x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/RECORD +6 -6
- {x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -438,19 +438,15 @@ class RotaryEmbedding(nn.Module):
 
     @autocast(enabled = False)
     def forward(self, t):
-
+        max_pos = t.max()+1
 
-
-
-        t = t / self.interpolation_factor
-
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.cat((freqs, freqs), dim = -1)
 
         if not exists(self.scale):
             return freqs, 1.
 
-        power = (
+        power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)
 
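For context, this hunk reworks RotaryEmbedding.forward: the interpolation factor is now applied to the computed frequencies instead of to the positions (equivalent up to floating-point rounding, since the einsum is linear in t), and the xpos decay power is derived from the positions themselves, centered on max_pos // 2. Below is a minimal standalone sketch of what the new version computes; the function name and signature are illustrative rather than the library's API, and inv_freq, scale, scale_base, and interpolation_factor stand in for the module's attributes.

```python
import torch
from einops import rearrange

def rotary_freqs_and_scale(t, inv_freq, scale, scale_base, interpolation_factor = 1.):
    # t: 1-D tensor of integer positions, e.g. torch.arange(seq_len)
    # Mirrors the new hunk: the interpolation factor divides the frequencies
    # after the einsum, which equals dividing the positions beforehand
    # because the einsum is linear in t.
    max_pos = t.max() + 1

    freqs = torch.einsum('i , j -> i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
    freqs = torch.cat((freqs, freqs), dim = -1)

    if scale is None:
        return freqs, 1.

    # xpos-style decay, now centered on the midpoint of the positions seen
    power = (t - (max_pos // 2)) / scale_base
    scale = scale ** rearrange(power, 'n -> n 1')
    scale = torch.cat((scale, scale), dim = -1)
    return freqs, scale

# usage sketch (hypothetical values)
pos = torch.arange(8)
inv_freq = 1. / (10000 ** (torch.arange(0, 64, 2).float() / 64))
freqs, _ = rotary_freqs_and_scale(pos, inv_freq, scale = None, scale_base = 512)
```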
@@ -466,6 +462,7 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
     freqs = freqs[-seq_len:, :]
+    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
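The added line makes apply_rotary_pos_emb slice the xpos scale the same way freqs was already sliced, so both stay aligned to the last seq_len positions when the cached rotary tensors cover more positions than the current input. A shape-only sketch, with made-up dimensions, of the slicing involved:

```python
import torch

# Hypothetical shapes, purely to illustrate the slicing the new line adds.
seq_len, full_len, dim_head = 4, 16, 64

t     = torch.randn(2, 8, seq_len, dim_head)   # (batch, heads, seq, dim) being rotated
freqs = torch.randn(full_len, dim_head)        # rotary angles cached for full_len positions
scale = torch.randn(full_len, dim_head)        # xpos scale cached for the same positions

freqs = freqs[-seq_len:, :]
# New in this diff: keep the scale aligned with the sliced freqs
scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

assert freqs.shape[0] == scale.shape[0] == t.shape[-2]
```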
@@ -1267,8 +1264,6 @@ class AttentionLayers(nn.Module):
         # rotary positions
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
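This hunk drops a max_rotary_emb_length value that the surrounding code no longer uses; the surviving lines take the memory length from mems[0] only. A toy comparison of the removed expression versus what the kept lines compute, assuming mems is a list of optional (batch, mem_len, dim) tensors and exists mirrors the library's helper:

```python
import torch

def exists(val):
    return val is not None

# mems: per-layer memory tensors (or None), x: current token embeddings
mems = [torch.randn(1, 5, 512), None, torch.randn(1, 3, 512)]
x = torch.randn(1, 7, 512)

# Removed in this diff: the maximum memory-extended length over all layers
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))

# What the surviving lines use instead: the first layer's memory length
maybe_mem = mems[0]
mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

print(max_rotary_emb_length, mem_len + x.shape[1])  # 12 and 12 for these toy tensors
```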
{x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
 x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=3caIQMDP2pxVuAA-CdEteUqX9ikNSanrmzKjkvzogjE,63619
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
+x_transformers-1.27.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.14.dist-info/METADATA,sha256=fXXkd4baN2z6pg5aWlMy-6Jpwb6PtKH-Bntnr6EdYWg,662
+x_transformers-1.27.14.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.14.dist-info/RECORD,,
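The RECORD changes are the usual bookkeeping: new digest and size for the rebuilt x_transformers.py plus the renamed 1.27.14 dist-info entries. A wheel's RECORD stores each hash as urlsafe base64 of the raw SHA-256 digest with padding stripped, followed by the file size in bytes; below is a small sketch for checking an entry against a locally downloaded wheel (the wheel path is hypothetical):

```python
import base64, hashlib, zipfile

def record_hash_and_size(wheel_path, member):
    # Compute the "sha256=<urlsafe base64, no padding>" digest and byte size
    # that a wheel's RECORD file stores for one of its members.
    with zipfile.ZipFile(wheel_path) as zf:
        data = zf.read(member)
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'sha256={digest}', len(data)

# e.g. the result for 'x_transformers/x_transformers.py' in the 1.27.14 wheel
# should match the RECORD line above: sha256=3caIQMDP2pxVuAA-CdEteUqX9ikNSanrmzKjkvzogjE, 63619
# print(record_hash_and_size('x_transformers-1.27.14-py3-none-any.whl', 'x_transformers/x_transformers.py'))
```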
{x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/LICENSE
File without changes
{x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/WHEEL
File without changes
{x_transformers-1.27.11.dist-info → x_transformers-1.27.14.dist-info}/top_level.txt
File without changes