x-transformers 1.42.9__tar.gz → 1.42.11__tar.gz
- {x_transformers-1.42.9/x_transformers.egg-info → x_transformers-1.42.11}/PKG-INFO +1 -1
- {x_transformers-1.42.9 → x_transformers-1.42.11}/setup.py +1 -1
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/autoregressive_wrapper.py +1 -1
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/continuous.py +20 -3
- {x_transformers-1.42.9 → x_transformers-1.42.11/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.9 → x_transformers-1.42.11}/LICENSE +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/README.md +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/setup.cfg +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/tests/test_x_transformers.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/x_transformers.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.9 → x_transformers-1.42.11}/x_transformers.egg-info/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py

@@ -48,7 +48,7 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
 
-    t = F.pad(t, (max_pad_len, 0), value = 0)
+    t = F.pad(t, (max_pad_len, 0), value = pad_id)
     offset = max_pad_len - pad_lens
 
     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
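For context, the one-line change above swaps a hard-coded pad value of 0 for the `pad_id` argument when left-padding prompts before right-aligning them. Below is a minimal, self-contained sketch of that behaviour: the `align_right_sketch` name, the 2-D token example, and the explicit `int(...)` conversion are illustrative additions, while the indexing logic follows the context lines of the hunk.

```python
import torch
import torch.nn.functional as F

# sketch of right-aligning variable-length prompts, assuming the lines not
# shown in the hunk derive pad_lens / max_pad_len from per-sample lengths
def align_right_sketch(t, lens, pad_id = 0):
    batch, seq_len = t.shape
    device = t.device

    pad_lens = seq_len - lens                 # left padding needed per row
    max_pad_len = int(pad_lens.amax())

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    # the 1.42.11 fix: pad with `pad_id` instead of a hard-coded 0
    t = F.pad(t, (max_pad_len, 0), value = pad_id)
    offset = max_pad_len - pad_lens

    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
    return aligned

tokens = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])   # trailing 0s are filler
lens = torch.tensor([3, 2])                            # true prompt lengths
print(align_right_sketch(tokens, lens, pad_id = -1))
# tensor([[-1,  5,  6,  7],
#         [-1, -1,  8,  9]])
```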
x_transformers/continuous.py

@@ -2,7 +2,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from einops import pack, repeat, unpack
+import einx
+from einops import reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -24,6 +25,15 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def masked_mean(t, mask):
+    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
+
+    num = reduce(t, 'b n d -> b', 'sum')
+    den = mask.sum(dim = -1)
+
+    masked_average = num / den.clamp(min = 1.)
+    return masked_average
+
 # main classes
 
 class ContinuousTransformerWrapper(nn.Module):
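The new `masked_mean` helper zeroes out masked positions, sums each sequence over its tokens and feature dimension, and divides by the per-sequence token count (clamped to avoid division by zero). A plain-PyTorch re-expression, shown only to illustrate the semantics (the library itself uses `einx.where` and `einops.reduce` as above):

```python
import torch

def masked_mean_plain(t, mask):
    # t: (batch, seq, dim) per-token values, mask: (batch, seq) booleans
    t = t.masked_fill(~mask[..., None], 0.)   # zero out masked positions
    num = t.sum(dim = (1, 2))                 # sum over tokens and features -> (batch,)
    den = mask.sum(dim = -1)                  # valid tokens per sequence
    return num / den.clamp(min = 1.)          # per-sequence mean, safe if a mask is all False

t = torch.ones(2, 3, 4)
mask = torch.tensor([[True, True, True], [True, False, False]])
print(masked_mean_plain(t, mask))             # tensor([4., 4.]) - each sequence weighted equally
```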
@@ -169,12 +179,15 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none')
+        loss_fn = nn.MSELoss(reduction = 'none'),
+        equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
+
         self.loss_fn = loss_fn
+        self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
     def generate(self, start_tokens, seq_len, **kwargs):
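A hedged usage sketch of the new constructor argument: the wrapper construction below follows the pattern from the project README, the dimensions and depth are arbitrary, and only `equal_loss_weight_batch` is new in this release.

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder, ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = 4
    )
)

wrapper = ContinuousAutoregressiveWrapper(
    model,
    equal_loss_weight_batch = True   # weight each sequence equally when a mask is passed
)

x = torch.randn(2, 256, 32)
mask = torch.ones(2, 256).bool()
mask[1, 100:] = False                # second sequence is shorter

loss = wrapper(x, mask = mask)
loss.backward()
```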
@@ -222,6 +235,10 @@ class ContinuousAutoregressiveWrapper(nn.Module):
 
         if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
-            loss = loss[mask]
+
+            if self.equal_loss_weight_batch:
+                loss = masked_mean(loss, mask)
+            else:
+                loss = loss[mask]
 
         return loss.mean()
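To make the weighting difference concrete, here is a toy comparison of the two branches above: with `equal_loss_weight_batch = False` every valid token counts equally, so long sequences dominate the batch loss, whereas with `True` each sequence is averaged first and the batch is averaged afterwards. The numbers are made up purely for illustration.

```python
import torch

# toy per-token losses: batch of 2, seq len 4, feature dim 1
loss = torch.tensor([[[1.], [1.], [1.], [1.]],      # long sequence, loss 1 everywhere
                     [[5.], [0.], [0.], [0.]]])     # short sequence, one valid token with loss 5
mask = torch.tensor([[True, True, True, True],
                     [True, False, False, False]])

# default behaviour: every valid token counts equally, long sequence dominates
per_token = loss[mask].mean()                        # (1*4 + 5*1) / 5 = 1.8

# equal_loss_weight_batch = True: average within each sequence, then across the batch
per_seq = torch.stack([loss[0][mask[0]].mean(),      # 1.0
                       loss[1][mask[1]].mean()])     # 5.0
equal_weighted = per_seq.mean()                      # 3.0

print(per_token.item(), equal_weighted.item())       # 1.8 3.0
```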