x-transformers 1.27.19__py3-none-any.whl → 1.27.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +9 -10
- {x_transformers-1.27.19.dist-info → x_transformers-1.27.20.dist-info}/METADATA +1 -1
- {x_transformers-1.27.19.dist-info → x_transformers-1.27.20.dist-info}/RECORD +6 -6
- {x_transformers-1.27.19.dist-info → x_transformers-1.27.20.dist-info}/WHEEL +1 -1
- {x_transformers-1.27.19.dist-info → x_transformers-1.27.20.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.19.dist-info → x_transformers-1.27.20.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -489,16 +489,6 @@ class Scale(nn.Module):
|
|
489
489
|
|
490
490
|
return (scale_fn(out[0]), *out[1:])
|
491
491
|
|
492
|
-
class ScaleNorm(nn.Module):
|
493
|
-
def __init__(self, dim, eps = 1e-5):
|
494
|
-
super().__init__()
|
495
|
-
self.eps = eps
|
496
|
-
self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
|
497
|
-
|
498
|
-
def forward(self, x):
|
499
|
-
norm = torch.norm(x, dim = -1, keepdim = True)
|
500
|
-
return x / norm.clamp(min = self.eps) * self.g
|
501
|
-
|
502
492
|
class LayerNorm(nn.Module):
|
503
493
|
def __init__(self, dim):
|
504
494
|
"""
|
@@ -514,6 +504,15 @@ class LayerNorm(nn.Module):
|
|
514
504
|
if version.parse(torch.__version__) >= version.parse('2.1.0'):
|
515
505
|
LayerNorm = partial(nn.LayerNorm, bias = False)
|
516
506
|
|
507
|
+
class ScaleNorm(nn.Module):
|
508
|
+
def __init__(self, dim):
|
509
|
+
super().__init__()
|
510
|
+
self.scale = dim ** 0.5
|
511
|
+
self.g = nn.Parameter(torch.ones(1))
|
512
|
+
|
513
|
+
def forward(self, x):
|
514
|
+
return F.normalize(x, dim = -1) * self.scale * self.g
|
515
|
+
|
517
516
|
class RMSNorm(nn.Module):
|
518
517
|
def __init__(self, dim):
|
519
518
|
super().__init__()
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
|
|
4
4
|
x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
|
10
|
-
x_transformers-1.27.
|
11
|
-
x_transformers-1.27.
|
12
|
-
x_transformers-1.27.
|
13
|
-
x_transformers-1.27.
|
14
|
-
x_transformers-1.27.
|
10
|
+
x_transformers-1.27.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.27.20.dist-info/METADATA,sha256=KfY4iK3HVlgVafNYc7GDUarZLvEacMDqzMACQBAZndU,662
|
12
|
+
x_transformers-1.27.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.27.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.27.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|