x-transformers 2.1.16__py3-none-any.whl → 2.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +30 -23
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/METADATA +1 -1
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/RECORD +5 -5
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )

 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange

 # helper functions
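The newly imported pack and unpack are what later lets generation accept a prompt with or without a leading batch dimension. A minimal sketch of the einops round-trip, with illustrative tensor names that are not part of the package:

import torch
from einops import pack, unpack

prompt = torch.randint(0, 256, (10,))        # unbatched prompt of 10 token ids

packed, ps = pack([prompt], '* d')           # shape (1, 10); ps remembers the original leading shape
packed = torch.cat((packed, torch.randint(0, 256, (1, 5))), dim = -1)

out, = unpack(packed, ps, '* n')             # shape (15,): the temporary batch axis is removed again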
@@ -69,6 +69,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens

+        self.to_forward_logits = nn.Linear(dim, num_tokens, bias = False)
+
         # the suffix token

         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -126,7 +128,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv =
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +138,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape

         out = prompts
@@ -197,14 +201,19 @@ class BeliefStateWrapper(Module):

             # concat sample

-            out = torch.cat((out, sample), dim
+            out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]

-
+        out, = unpack(out, batch_ps, '* n')
+
+        return out

     def forward(
         self,
         seq,
-
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device

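Taken together, the sampling changes mean a prompt no longer needs an explicit batch dimension, only the newly generated tokens are returned (the prompt is sliced off via out[:, orig_seq_len:]), and kv caching now defaults to off. A hedged usage sketch; the method name generate_with_suffix_cond comes from the package source rather than this diff, and model construction is omitted:

prompt = torch.randint(0, 256, (32,))                       # single, unbatched prompt

sampled = model.generate_with_suffix_cond(prompt, 64, cache_kv = False)

# `sampled` holds only the new tokens, with the prompt's batch-less leading shape restored by unpack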
@@ -244,6 +253,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices

         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2

         fb_pairs = fb_pairs[valid_mask]
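For context, the valid_mask line above keeps only forward/backward index pairs that are at least two positions apart. A toy illustration, assuming the pairs come from a cartesian product of positions (how fb_pairs is actually built is not shown in this hunk):

import torch

seq_len = 5
positions = torch.arange(seq_len)
fb_pairs = torch.cartesian_prod(positions, positions)   # all (forward, backward) index pairs

fi, bi = fb_pairs.unbind(dim = -1)
valid_mask = (bi - fi) >= 2                             # at least one token must sit between the two positions

fb_pairs = fb_pairs[valid_mask]                         # keeps (0, 2), (0, 3), ..., drops (0, 0), (0, 1), (2, 1), ...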
@@ -265,8 +275,9 @@ class BeliefStateWrapper(Module):

         labels_fi, labels_bi = (fi + 1), bi

-        forward_labels, backward_labels = seq[:,
-
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)

         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions

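The rewritten label construction gathers the next-token target for each forward index and the token at each backward index, then concatenates the two. A small shape sketch with toy values, following the names in the diff:

import torch

seq = torch.randint(0, 256, (2, 8))                                      # (batch, seq_len)
fi = torch.tensor([0, 1, 3])
bi = torch.tensor([2, 4, 6])

labels_fi, labels_bi = fi + 1, bi
forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]   # each (2, 3)
labels = torch.cat((forward_labels, backward_labels), dim = -1)          # (2, 6): forward targets, then backward targets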
@@ -281,7 +292,7 @@ class BeliefStateWrapper(Module):

         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )

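The labels built above line up with the rearranged logits: the vocabulary axis moves to position 1, as F.cross_entropy expects, while the forward/backward halves stay aligned with the label ordering. A shape-only sketch with toy sizes:

import torch
import torch.nn.functional as F
from einops import rearrange

b, n, l = 2, 6, 10                        # batch, number of fb pairs, vocab size (toy numbers)
logits = torch.randn(b, n, 2 * l)         # forward and backward logits concatenated on the last dim
labels = torch.randint(0, l, (b, 2 * n))  # forward targets then backward targets, as above

loss = F.cross_entropy(
    rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),   # (b, l, 2n): class dim now at position 1
    labels
)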
@@ -290,12 +301,12 @@ class BeliefStateWrapper(Module):
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)

-
-
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)

             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-
+                terminal_labels
             )

         loss = (
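The new terminal labels flag pairs whose forward and backward indices are exactly two apart (a single token between them) and broadcast that flag across the batch for the binary cross entropy. A self-contained toy version:

import torch
import torch.nn.functional as F
from einops import repeat

fi = torch.tensor([0, 1, 3])
bi = torch.tensor([2, 4, 5])
batch = 2

terminal_labels = ((bi - fi) == 2).float()                        # tensor([1., 0., 1.])
terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)  # (2, 3): same flags for every batch element

is_end_loss = F.binary_cross_entropy_with_logits(torch.randn(batch, 3), terminal_labels)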
@@ -303,6 +314,11 @@ class BeliefStateWrapper(Module):
             is_end_loss * self.pred_terminal_loss_weight
         )

+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting

         if self.needs_loss_weight:
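The new return_loss_only flag short-circuits before the loss weighting and backward bookkeeping below, which is handy for validation or logging. A hedged usage sketch in which model and seq are assumed to exist:

loss = model(seq, return_loss_only = True)   # returns the loss without running the internal backward logic
print(loss.item())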
@@ -312,18 +328,9 @@ class BeliefStateWrapper(Module):

         # backwards

-
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()

-
-
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)

         return loss
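Since the wrapper now calls (loss * loss_scale).backward() itself rather than patching loss.backward, loss_scale is the natural hook for gradient accumulation. A hedged sketch; model, optimizer and the batched seq tensor are assumptions, not part of this diff:

accum_steps = 4

optimizer.zero_grad()

for seq_chunk in seq.chunk(accum_steps, dim = 0):
    model(seq_chunk, loss_scale = 1. / accum_steps)   # the internal backward accumulates scaled gradients

optimizer.step()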
{x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=7zi_inz75jZVQE4zFVA2H45yR2V0RLTgX_aGO86rN9s,9878
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.17.dist-info/METADATA,sha256=PzCTyZz0yqv6vXVpzZwtWyu6sjR-kKwcgVPAsrn5TnI,87571
+x_transformers-2.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.17.dist-info/RECORD,,
{x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/WHEEL
File without changes

{x_transformers-2.1.16.dist-info → x_transformers-2.1.17.dist-info}/licenses/LICENSE
File without changes