x-transformers 2.1.16__py3-none-any.whl → 2.1.18__py3-none-any.whl

This diff compares the published contents of two versions of the package as they appear in their public registry. It is provided for informational purposes only.
--- x_transformers/belief_state_wrapper.py
+++ x_transformers/belief_state_wrapper.py
@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )
 
 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
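
The two new einops imports support the generation changes below: pack normalizes a possibly-unbatched prompt into a batched tensor while recording the original leading shape, and unpack restores that shape on the way out. A minimal sketch of the round trip (sizes are illustrative, not from the source):

    import torch
    from einops import pack, unpack

    t = torch.randint(0, 256, (10,))       # a lone sequence, no batch dimension
    packed, ps = pack([t], '* n')          # packed.shape == (1, 10); ps records the missing leading dims
    restored, = unpack(packed, ps, '* n')  # shape (10,) again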
@@ -126,7 +126,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv = True,
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +136,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
 
+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape
 
         out = prompts
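
Packing with the '* d' pattern is what lets a caller now pass a single prompt without a batch dimension: a 1-D tensor of token ids is lifted to shape (1, n) before generation, while an already-batched (b, n) prompt passes through unchanged. batch_ps records which case occurred so the sample can be unpacked to match at the end of the method (next hunk).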
@@ -197,14 +199,19 @@ class BeliefStateWrapper(Module):
 
         # concat sample
 
-        out = torch.cat((out, sample), dim=-1)
+        out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]
 
-        return out[:, orig_seq_len:]
+        out, = unpack(out, batch_ps, '* n')
+
+        return out
 
     def forward(
         self,
         seq,
-        backward = True
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
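On the way out, the output is first sliced past the original prompt length and then unpacked with the stored batch_ps, so an unbatched prompt yields an unbatched sample. A hedged usage sketch; the generate method's full name is not visible in this hunk, and wrapper is an assumed, already-constructed BeliefStateWrapper:

    prompt = torch.randint(0, 256, (5,))                     # note: no batch dimension
    sample = wrapper.generate_with_suffix_cond(prompt, 64)   # assumed method name
    # sample comes back 1-D, mirroring the prompt, instead of shape (1, n)

The same hunk also reworks the forward signature: the backward flag is gone, replaced by return_loss_only and loss_scale, both handled in the hunks below.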
@@ -244,6 +251,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices
 
         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2
 
         fb_pairs = fb_pairs[valid_mask]
@@ -265,8 +273,9 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), bi
 
-        forward_labels, backward_labels = seq[:, fi], seq[:, bi]
-        labels = stack((forward_labels, backward_labels), dim = -1)
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)
 
         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
 
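This hunk fixes an off-by-one in the labels: labels_fi is defined as fi + 1 two lines above, but the old code indexed seq at fi, labeling the forward prediction with the token at the forward position itself rather than the next token. For example, with fi = 3 and bi = 7, the forward label is now seq[:, 4] and the backward label seq[:, 7]. Switching from stack to cat also emits the flat 'b (fb n)' layout that the cross-entropy call below now consumes directly.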
@@ -281,7 +290,7 @@ class BeliefStateWrapper(Module):
 
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)'),
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
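Since the labels are already concatenated into 'b (fb n)' form, the extra rearrange on the target can be dropped. A small shape check with made-up sizes (not from the source), showing the rearranged logits and the concatenated labels line up for F.cross_entropy:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, n, l = 2, 3, 11                        # illustrative batch, pair count, vocab size
    logits = torch.randn(b, n, 2 * l)         # forward and backward logits fused on the last axis
    fwd = torch.randint(0, l, (b, n))
    bwd = torch.randint(0, l, (b, n))
    labels = torch.cat((fwd, bwd), dim = -1)  # (2, 6), i.e. 'b (fb n)'

    loss = F.cross_entropy(
        rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),  # (2, 11, 6), class dim second
        labels
    )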
@@ -290,12 +299,12 @@ class BeliefStateWrapper(Module):
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)
 
-            labels = ((bi - fi) == 2).float() # distance is exactly 2
-            labels = repeat(labels, 'n -> b n', b = batch)
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
 
             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-                labels
+                terminal_labels
             )
 
             loss = (
@@ -303,6 +312,11 @@ class BeliefStateWrapper(Module):
                 is_end_loss * self.pred_terminal_loss_weight
             )
 
+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
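
return_loss_only exits before the loss weighting and the internal backward pass below, which is useful for evaluation loops or when the caller wants to inspect the loss without touching gradients. A hedged usage sketch (wrapper is an assumed BeliefStateWrapper instance):

    with torch.no_grad():
        eval_loss = wrapper(seq, return_loss_only = True)  # nothing calls .backward() internally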
@@ -312,18 +326,9 @@ class BeliefStateWrapper(Module):
 
         # backwards
 
-        orig_backward = getattr(loss, 'backward')
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()
 
-        if backward:
-            patched_backward_fn()
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)
 
         return loss
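
The monkey-patched loss.backward is replaced by an unconditional, explicit backward pass, scaled by loss_scale (useful for gradient accumulation, where each micro-batch loss is scaled by 1 / num_accum_steps before .backward()). The two follow-up .backward(grad) calls bridge gradients back into the trunk: forward_embeds and backward_embeds are presumably detached copies of orig_forward_embeds and orig_backward_embeds, so once the loss backward fills their .grad, re-injecting those grads continues backpropagation through the original graph. The detach-and-bridge pattern in isolation (a sketch, not the wrapper's actual tensors):

    import torch

    x = torch.randn(4, requires_grad = True)
    orig = x * 2                                # still attached to the autograd graph
    detached = orig.detach().requires_grad_()   # stand-in for forward_embeds in the wrapper

    loss = (detached ** 2).sum()
    loss.backward()                             # populates detached.grad only
    orig.backward(detached.grad)                # re-inject: gradients now reach x
    assert x.grad is not None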
--- x_transformers-2.1.16.dist-info/METADATA
+++ x_transformers-2.1.18.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.16
+Version: 2.1.18
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.1.16.dist-info/RECORD
+++ x_transformers-2.1.18.dist-info/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=VjUB73yFBWevN6xMc6_1s-Yc58pJv8SDAUUEXwpR-W0,9842
+x_transformers/belief_state_wrapper.py,sha256=MkZqVpwEZJfZGTCB24n2DvGbnFSZefk9LZ-HA5LqGS0,9803
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.16.dist-info/METADATA,sha256=MXaag1fuq1BsAmyh9k8sSRiHpy-jAUQW2Hn1GC53MnQ,87571
-x_transformers-2.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.16.dist-info/RECORD,,
+x_transformers-2.1.18.dist-info/METADATA,sha256=bW46vXghSJjACgIgOlf5UYzX9fZz6ryEAoevC0DnE-0,87571
+x_transformers-2.1.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.18.dist-info/RECORD,,