x-transformers 1.27.20__tar.gz → 1.27.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.27.20/x_transformers.egg-info → x_transformers-1.27.22}/PKG-INFO +1 -1
- {x_transformers-1.27.20 → x_transformers-1.27.22}/README.md +1 -2
- {x_transformers-1.27.20 → x_transformers-1.27.22}/setup.py +1 -1
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/xval.py +22 -7
- {x_transformers-1.27.20 → x_transformers-1.27.22/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.27.20 → x_transformers-1.27.22}/LICENSE +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/setup.cfg +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/__init__.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/attend.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/continuous.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/dpo.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/x_transformers.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1437,11 +1437,10 @@ model = XValAutoregressiveWrapper(model)
|
|
1437
1437
|
|
1438
1438
|
ids = torch.randint(0, 4, (1, 777))
|
1439
1439
|
nums = torch.randn(1, 777)
|
1440
|
-
mask = torch.ones(1, 777).bool()
|
1441
1440
|
|
1442
1441
|
# train on a lot of data above
|
1443
1442
|
|
1444
|
-
loss = model(ids, nums, mask = mask)
|
1443
|
+
loss = model(ids, nums)
|
1445
1444
|
loss.backward()
|
1446
1445
|
|
1447
1446
|
# then generate
|
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
260
260
|
inp, target = x[:, :-1], x[:, 1:]
|
261
261
|
x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
|
262
262
|
|
263
|
+
# ignore index
|
264
|
+
|
265
|
+
target_mask = target != self.ignore_index
|
266
|
+
|
267
|
+
# key padding mask
|
268
|
+
|
263
269
|
mask = kwargs.get('mask', None)
|
264
|
-
if exists(mask) and mask.shape[1] == x.shape[1]:
|
265
|
-
|
266
|
-
|
270
|
+
if exists(mask):
|
271
|
+
target_mask &= mask
|
272
|
+
|
273
|
+
if mask.shape[1] == x.shape[1]:
|
274
|
+
mask = mask[:, :-1]
|
275
|
+
kwargs['mask'] = mask
|
267
276
|
|
268
277
|
logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
|
269
278
|
|
@@ -271,17 +280,23 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
271
280
|
|
272
281
|
cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)
|
273
282
|
|
274
|
-
|
283
|
+
# protect against nan in `x_num` input tensor
|
284
|
+
|
285
|
+
target_is_number_mask = target == self.net.numerical_token_id
|
286
|
+
x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
|
287
|
+
|
288
|
+
# numerical mse loss
|
275
289
|
|
276
290
|
numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
|
277
291
|
|
278
292
|
numerical_mse_loss = numerical_mse_loss * target_mask
|
293
|
+
numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
|
279
294
|
|
280
|
-
|
295
|
+
# combine losses
|
281
296
|
|
282
|
-
|
283
|
-
loss = loss[mask]
|
297
|
+
loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
|
284
298
|
|
299
|
+
loss = loss[target_mask]
|
285
300
|
loss = loss.mean()
|
286
301
|
|
287
302
|
if not return_loss_breakdown:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/nonautoregressive_wrapper.py
RENAMED
File without changes
|
File without changes
|
{x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/xl_autoregressive_wrapper.py
RENAMED
File without changes
|
File without changes
|
{x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|