x-transformers 1.37.3__py3-none-any.whl → 1.37.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/nonautoregressive_wrapper.py +3 -1
- x_transformers/xl_autoregressive_wrapper.py +6 -4
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.4.dist-info}/METADATA +1 -1
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.4.dist-info}/RECORD +7 -7
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.4.dist-info}/top_level.txt +0 -0
@@ -309,9 +309,11 @@ class NonAutoregressiveWrapper(nn.Module):
|
|
309
309
|
with context():
|
310
310
|
logits = self.net(masked, **kwargs)
|
311
311
|
|
312
|
+
loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
|
313
|
+
|
312
314
|
# cross entropy loss
|
313
315
|
|
314
|
-
loss =
|
316
|
+
loss = loss_fn(
|
315
317
|
logits[mask],
|
316
318
|
orig_seq[mask]
|
317
319
|
)
|
@@ -40,7 +40,7 @@ class XLAutoregressiveWrapper(nn.Module):
|
|
40
40
|
eos_token = None,
|
41
41
|
temperature = 1.,
|
42
42
|
filter_logits_fn = top_k,
|
43
|
-
|
43
|
+
filter_kwargs: dict = dict(),
|
44
44
|
mems = None,
|
45
45
|
**kwargs
|
46
46
|
):
|
@@ -88,7 +88,7 @@ class XLAutoregressiveWrapper(nn.Module):
|
|
88
88
|
mems = cache.mems
|
89
89
|
|
90
90
|
logits = logits[:, -1]
|
91
|
-
filtered_logits = filter_logits_fn(logits,
|
91
|
+
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
|
92
92
|
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
93
93
|
|
94
94
|
sample = torch.multinomial(probs, 1)
|
@@ -131,7 +131,9 @@ class XLAutoregressiveWrapper(nn.Module):
|
|
131
131
|
|
132
132
|
split_x = x.split(max_seq_len, dim = -1)
|
133
133
|
split_labels = labels.split(max_seq_len, dim = -1)
|
134
|
-
loss_weights = tuple(
|
134
|
+
loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)
|
135
|
+
|
136
|
+
loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
|
135
137
|
|
136
138
|
# go through each chunk and derive weighted losses
|
137
139
|
|
@@ -146,7 +148,7 @@ class XLAutoregressiveWrapper(nn.Module):
|
|
146
148
|
**kwargs
|
147
149
|
)
|
148
150
|
|
149
|
-
loss =
|
151
|
+
loss = loss_fn(
|
150
152
|
rearrange(logits, 'b n c -> b c n'),
|
151
153
|
chunk_labels,
|
152
154
|
ignore_index = ignore_index
|
@@ -4,12 +4,12 @@ x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJy
|
|
4
4
|
x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
|
-
x_transformers/nonautoregressive_wrapper.py,sha256=
|
7
|
+
x_transformers/nonautoregressive_wrapper.py,sha256=FDSNVq7wkq4c9muHKngaVPvpHfISfqA8MI7fZZ0KcOY,10503
|
8
8
|
x_transformers/x_transformers.py,sha256=gOJBZzOJMu5RkIsxw9TZtde4Sx--D18yX8LjrYIsPbE,83677
|
9
|
-
x_transformers/xl_autoregressive_wrapper.py,sha256=
|
9
|
+
x_transformers/xl_autoregressive_wrapper.py,sha256=mwygaOG6ECJGrzG1ViA6Xxu9m_WSYOzPUDzUhkt4m0U,4188
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.37.
|
12
|
-
x_transformers-1.37.
|
13
|
-
x_transformers-1.37.
|
14
|
-
x_transformers-1.37.
|
15
|
-
x_transformers-1.37.
|
11
|
+
x_transformers-1.37.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.37.4.dist-info/METADATA,sha256=H1oc8hkvPpEoS4ioSOn1wVSHjcTJ3OJcFdJuJuuWdl4,661
|
13
|
+
x_transformers-1.37.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
14
|
+
x_transformers-1.37.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.37.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|