x-transformers 1.27.19.tar.gz → 1.27.21.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x-transformers-1.27.19/x_transformers.egg-info → x_transformers-1.27.21}/PKG-INFO +1 -1
- {x-transformers-1.27.19 → x_transformers-1.27.21}/README.md +1 -2
- {x-transformers-1.27.19 → x_transformers-1.27.21}/setup.py +1 -1
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/x_transformers.py +9 -10
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/xval.py +9 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21/x_transformers.egg-info}/PKG-INFO +1 -1
- {x-transformers-1.27.19 → x_transformers-1.27.21}/LICENSE +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/setup.cfg +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/__init__.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/attend.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/autoregressive_wrapper.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/continuous.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/dpo.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers.egg-info/requires.txt +0 -0
- {x-transformers-1.27.19 → x_transformers-1.27.21}/x_transformers.egg-info/top_level.txt +0 -0
```diff
--- x-transformers-1.27.19/README.md
+++ x_transformers-1.27.21/README.md
@@ -1437,11 +1437,10 @@ model = XValAutoregressiveWrapper(model)
 
 ids = torch.randint(0, 4, (1, 777))
 nums = torch.randn(1, 777)
-mask = torch.ones(1, 777).bool()
 
 # train on a lot of data above
 
-loss = model(ids, nums, mask = mask)
+loss = model(ids, nums)
 loss.backward()
 
 # then generate
```
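The README change above drops the explicit `mask` from the xval training example. A minimal sketch of the corrected usage, assuming the wrapper construction shown earlier in that README section; the `XValTransformerWrapper` constructor arguments here are illustrative placeholders, not taken from this diff:

```python
import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

# illustrative hyperparameters -- consult the README / source for the real signature
model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,    # the token id that stands in for "a number goes here"
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

model = XValAutoregressiveWrapper(model)

ids  = torch.randint(0, 4, (1, 777))   # token ids, some equal to the numerical token
nums = torch.randn(1, 777)             # continuous values aligned with `ids`

loss = model(ids, nums)                # no explicit mask argument anymore
loss.backward()
```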
```diff
--- x-transformers-1.27.19/x_transformers/x_transformers.py
+++ x_transformers-1.27.21/x_transformers/x_transformers.py
@@ -489,16 +489,6 @@ class Scale(nn.Module):
 
         return (scale_fn(out[0]), *out[1:])
 
-class ScaleNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
-
-    def forward(self, x):
-        norm = torch.norm(x, dim = -1, keepdim = True)
-        return x / norm.clamp(min = self.eps) * self.g
-
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         """
```
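For reference, the removed `ScaleNorm` divided by the clamped L2 norm and multiplied by a scalar gain initialized to `dim ** -0.5`, so at initialization its outputs had an L2 norm of roughly `dim ** -0.5`. A quick check of that behavior, with the forward pass lifted directly from the removed lines:

```python
import torch

def old_scale_norm(x, g, eps = 1e-5):
    # removed implementation: divide by the (clamped) L2 norm over the last dim,
    # then multiply by a learned scalar gain
    norm = torch.norm(x, dim = -1, keepdim = True)
    return x / norm.clamp(min = eps) * g

x = torch.randn(2, 8, 512)
g = torch.ones(1) * (512 ** -0.5)                  # init from the removed __init__
print(old_scale_norm(x, g).norm(dim = -1).mean())  # ≈ 512 ** -0.5 ≈ 0.044
```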
```diff
--- x-transformers-1.27.19/x_transformers/x_transformers.py
+++ x_transformers-1.27.21/x_transformers/x_transformers.py
@@ -514,6 +504,15 @@ class LayerNorm(nn.Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
+class ScaleNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.g = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * self.g
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
```
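The re-added `ScaleNorm` is not a pure move: it unit-normalizes with `F.normalize` and scales by `dim ** 0.5` times a gain initialized to one, so outputs start at a norm of about `sqrt(dim)`, the initialization used in the ScaleNorm paper ("Transformers without Tears"), rather than the `dim ** -0.5` of the version removed above. A sketch contrasting the two forward passes at initialization:

```python
import torch
import torch.nn.functional as F

dim = 512
x = torch.randn(2, 8, dim)

# re-added version: unit-normalize, then scale by sqrt(dim) and a gain initialized to 1
new_out = F.normalize(x, dim = -1) * (dim ** 0.5) * torch.ones(1)
print(new_out.norm(dim = -1).mean())   # ≈ sqrt(512) ≈ 22.6

# removed version started at norm ≈ dim ** -0.5 instead (see the sketch above),
# so the refactor changes the initial output scale, not just where the class lives
old_out = x / x.norm(dim = -1, keepdim = True).clamp(min = 1e-5) * (dim ** -0.5)
print(old_out.norm(dim = -1).mean())   # ≈ 0.044
```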
```diff
--- x-transformers-1.27.19/x_transformers/xval.py
+++ x_transformers-1.27.21/x_transformers/xval.py
@@ -271,8 +271,17 @@ class XValAutoregressiveWrapper(nn.Module):
 
         cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)
 
+        # protect against nan in `x_num` input tensor
+
+        target_is_number_mask = target == self.net.numerical_token_id
+        x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
+
+        # ignore index
+
         target_mask = target != self.ignore_index
 
+        # numerical mse loss
+
         numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
 
         numerical_mse_loss = numerical_mse_loss * target_mask
```
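The new `masked_fill` matters because masking the MSE after the fact cannot undo a NaN: `0 * nan` is still `nan`, so any NaN left in `x_num_target` at non-number positions would poison the summed loss (and its gradients) even though `target_mask` zeroes those positions out. A small self-contained demonstration of the failure mode and of the fix applied in this diff:

```python
import torch
import torch.nn.functional as F

pred   = torch.zeros(4)
target = torch.tensor([1., float('nan'), 2., 3.])   # nan where no number exists
mask   = torch.tensor([1., 0., 1., 1.])             # 0 at the nan position

# masking after the loss does NOT help: 0 * nan is still nan
print((F.mse_loss(pred, target, reduction = 'none') * mask).sum())       # nan

# the fix: zero out the target at non-number positions before the loss
safe_target = target.masked_fill(mask == 0., 0.)
print((F.mse_loss(pred, safe_target, reduction = 'none') * mask).sum())  # tensor(14.)
```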