x-transformers 2.1.5__py3-none-any.whl → 2.1.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
--- a/x_transformers/__init__.py
+++ b/x_transformers/__init__.py
@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
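Because the re-export above lands in the package's top-level module, the new wrapper should be importable directly from x_transformers as well as from its submodule. A small sketch, assuming the file being patched here is indeed x_transformers/__init__.py:

    # the submodule path always works; the top-level import relies on the re-export above
    from x_transformers.belief_state_wrapper import BeliefStateWrapper
    from x_transformers import BeliefStateWrapper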
--- a/x_transformers/belief_state.py
+++ b/x_transformers/belief_state_wrapper.py
@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import (
@@ -36,7 +36,8 @@ class BeliefStateWrapper(Module):
         self,
         forward_decoder: TransformerWrapper,
         backward_decoder: TransformerWrapper,
-        train_frac_forward_backward_pairs: float = 1.
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
@@ -70,9 +71,17 @@
         self.train_frac_fb_pairs = train_frac_forward_backward_pairs
         self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
 
+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
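The two changes above register a relative weight for the backward decoder's autoregressive loss (stored as a two-element loss_weights buffer) and add a backward flag to forward. A minimal construction sketch, assuming the usual TransformerWrapper/Decoder setup from x-transformers; the sizes and the 0.5 weight are illustrative, not taken from the diff:

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    # the forward and backward decoders must share the same embedding dimension,
    # per the assert in __init__
    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )
    backward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    wrapper = BeliefStateWrapper(
        forward_decoder = forward_decoder,
        backward_decoder = backward_decoder,
        backward_ar_loss_weight = 0.5  # illustrative: down-weight the backwards decoder's loss
    )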
@@ -149,14 +158,31 @@
 
         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
+            fb_loss = einx.multiply('b fb n, fb', fb_loss, self.loss_weights)
+            fb_loss = fb_loss.mean()
+
         # backwards
 
-        fb_loss.backward()
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside
 
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)
 
         return fb_loss
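With the .backward patching above, the default path (backward = True) runs the backward pass inside forward itself, while backward = False defers it so the training loop can call loss.backward() as usual and gradients still flow into the cached embeddings. Continuing the sketch above; the optimizer and data are illustrative:

    optimizer = torch.optim.Adam(wrapper.parameters(), lr = 3e-4)
    seq = torch.randint(0, 256, (2, 512))

    # default: the wrapper computes gradients inside forward()
    loss = wrapper(seq)
    optimizer.step()
    optimizer.zero_grad()

    # or defer the backward pass and invoke it from the training loop as usual
    loss = wrapper(seq, backward = False)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()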
--- a/x_transformers-2.1.5.dist-info/METADATA
+++ b/x_transformers-2.1.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.5
+Version: 2.1.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- a/x_transformers-2.1.5.dist-info/RECORD
+++ b/x_transformers-2.1.7.dist-info/RECORD
@@ -1,7 +1,7 @@
-x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
+x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state.py,sha256=GfYDeDqmhldozECgFsJ9zhd6O5NMvdYA5OwueVs8SB4,4742
+x_transformers/belief_state_wrapper.py,sha256=WoagJe_cHguWAWzhEAmeDD_TzUsQcMQkgP2Mt88wrAg,5827
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.5.dist-info/METADATA,sha256=-nQpm1eBGBXkLEXGiQK06NaIWa13CEnBujBNfTzvnJ8,87570
-x_transformers-2.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.5.dist-info/RECORD,,
+x_transformers-2.1.7.dist-info/METADATA,sha256=D7enhdRYdhTddt4kGSS-rSsvtHgtFaLYQH-cVuYxUVM,87570
+x_transformers-2.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.7.dist-info/RECORD,,