x-transformers 1.42.9__py3-none-any.whl → 1.42.11__py3-none-any.whl

x_transformers/autoregressive_wrapper.py
@@ -48,7 +48,7 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
 
-    t = F.pad(t, (max_pad_len, 0), value = 0)
+    t = F.pad(t, (max_pad_len, 0), value = pad_id)
     offset = max_pad_len - pad_lens
 
     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
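The one-line fix above threads the caller's pad_id through to F.pad, where the left-padding value had been hardcoded to 0. A minimal standalone sketch of the difference (not part of the release; toy tensor):

    import torch
    import torch.nn.functional as F

    t = torch.tensor([[5, 6, 7]])

    # before the fix: left padding was always 0, even if align_right
    # was called with a nonzero pad_id
    F.pad(t, (2, 0), value = 0)    # tensor([[0, 0, 5, 6, 7]])

    # after the fix: the caller's pad_id (here -1) is respected
    F.pad(t, (2, 0), value = -1)   # tensor([[-1, -1, 5, 6, 7]])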
x_transformers/continuous.py
@@ -2,7 +2,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from einops import pack, repeat, unpack
+import einx
+from einops import reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -24,6 +25,15 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def masked_mean(t, mask):
+    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
+
+    num = reduce(t, 'b n d -> b', 'sum')
+    den = mask.sum(dim = -1)
+
+    masked_average = num / den.clamp(min = 1.)
+    return masked_average
+
 # main classes
 
 class ContinuousTransformerWrapper(nn.Module):
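The new masked_mean helper zeroes out masked positions with einx.where, sums over the token and feature dimensions, and divides by the per-sequence count of valid tokens (clamped so fully masked rows do not divide by zero), yielding one value per batch element. A standalone sketch of its behavior on toy tensors (assumes einx and einops are installed, as the new imports require):

    import torch
    import einx
    from einops import reduce

    def masked_mean(t, mask):
        t = einx.where('b n, b n d, -> b n d', mask, t, 0.)  # zero masked positions
        num = reduce(t, 'b n d -> b', 'sum')                 # sum over tokens and dims
        den = mask.sum(dim = -1)                             # valid tokens per sequence
        return num / den.clamp(min = 1.)                     # guard fully masked rows

    loss = torch.ones(2, 4, 3)                          # uniform per-token losses
    mask = torch.tensor([[True, True, True, True],
                         [True, False, False, False]])  # lengths 4 and 1

    masked_mean(loss, mask)  # tensor([3., 3.]): both sequences reduce to the
                             # same value despite their different lengths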
@@ -169,12 +179,15 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none')
+        loss_fn = nn.MSELoss(reduction = 'none'),
+        equal_loss_weight_batch = False # if True, and a mask is passed in with variable-length sequences, each sequence is weighted equally in the loss (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
+
         self.loss_fn = loss_fn
+        self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
     def generate(self, start_tokens, seq_len, **kwargs):
@@ -222,6 +235,10 @@ class ContinuousAutoregressiveWrapper(nn.Module):
 
         if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
-            loss = loss[mask]
+
+            if self.equal_loss_weight_batch:
+                loss = masked_mean(loss, mask)
+            else:
+                loss = loss[mask]
 
         return loss.mean()
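With equal_loss_weight_batch = False (the default), loss[mask] flattens every unmasked token into one pool before the final loss.mean(), so longer sequences contribute more terms; with it set to True, masked_mean averages within each sequence first, so the final mean weights each sequence equally. A standalone sketch of the two reductions on a toy (batch, seq, dim) loss (not part of the release):

    import torch

    loss = torch.tensor([[[1.], [1.], [1.], [1.]],
                         [[5.], [0.], [0.], [0.]]])  # per-token losses, dim = 1
    mask = torch.tensor([[True, True, True, True],
                         [True, False, False, False]])

    # default: every unmasked token counts once, so the 4-token sequence dominates
    loss[mask].mean()  # (1 + 1 + 1 + 1 + 5) / 5 = 1.8

    # equal_loss_weight_batch = True: average per sequence, then across the batch
    torch.stack([l[m].mean() for l, m in zip(loss, mask)]).mean()  # (1 + 5) / 2 = 3.0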
x_transformers-1.42.9.dist-info/METADATA → x_transformers-1.42.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.9
+Version: 1.42.11
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.42.9.dist-info/RECORD → x_transformers-1.42.11.dist-info/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
-x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
-x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
+x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
+x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -9,8 +9,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=VxdA44EYQhVH1Rp7wreJ83I2e0Ea7VN_bFRE-iDXOI8,93833
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.9.dist-info/METADATA,sha256=k9r-D0b0xnf8gwE-SwwgybnfQpoRwiY0wthOn66xc6Y,689
-x_transformers-1.42.9.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.9.dist-info/RECORD,,
+x_transformers-1.42.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.11.dist-info/METADATA,sha256=0dlrRj5RehRfEhgK7M4ESmaNHuthe912XQiC7Hsim_8,690
+x_transformers-1.42.11.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.11.dist-info/RECORD,,