x-transformers 1.43.2__py3-none-any.whl → 1.43.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- x_transformers/x_transformers.py +22 -5
- {x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/METADATA +1 -1
- {x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/RECORD +6 -6
- {x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -38,6 +38,7 @@ class LayerIntermediates:
|
|
38
38
|
attn_z_loss: Tensor | None = None
|
39
39
|
mems: Tensor | None = None
|
40
40
|
memory_tokens: Tensor | None = None
|
41
|
+
logit_entropies: Tensor | None = None
|
41
42
|
|
42
43
|
LinearNoBias = partial(nn.Linear, bias = False)
|
43
44
|
|
@@ -136,6 +137,15 @@ def or_reduce(masks):
|
|
136
137
|
head = head | rest
|
137
138
|
return head
|
138
139
|
|
140
|
+
# entropy
|
141
|
+
|
142
|
+
def calc_entropy(
|
143
|
+
t: Tensor,
|
144
|
+
is_prob = False
|
145
|
+
):
|
146
|
+
prob = t.softmax(dim = -1) if not is_prob else t
|
147
|
+
return -(prob * log(prob)).sum(dim = -1)
|
148
|
+
|
139
149
|
# auxiliary loss helpers
|
140
150
|
|
141
151
|
def calc_z_loss(
|
@@ -2270,16 +2280,16 @@ class AttentionLayers(Module):
|
|
2270
2280
|
if self.need_condition:
|
2271
2281
|
final_norm = maybe(partial)(final_norm, **norm_kwargs)
|
2272
2282
|
|
2273
|
-
if self.resi_dual:
|
2274
|
-
x = x + final_norm(outer_residual)
|
2275
|
-
else:
|
2276
|
-
x = final_norm(x)
|
2277
|
-
|
2278
2283
|
# take care of multistreams if needed, use sum for now
|
2279
2284
|
|
2280
2285
|
if is_multistream:
|
2281
2286
|
x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
|
2282
2287
|
|
2288
|
+
if self.resi_dual:
|
2289
|
+
x = x + final_norm(outer_residual)
|
2290
|
+
else:
|
2291
|
+
x = final_norm(x)
|
2292
|
+
|
2283
2293
|
if not return_hiddens:
|
2284
2294
|
return x
|
2285
2295
|
|
@@ -2592,6 +2602,7 @@ class TransformerWrapper(Module):
|
|
2592
2602
|
return_embeddings = False,
|
2593
2603
|
return_logits_and_embeddings = False,
|
2594
2604
|
return_intermediates = False,
|
2605
|
+
return_logit_entropies = False,
|
2595
2606
|
mask = None,
|
2596
2607
|
return_mems = False,
|
2597
2608
|
return_attn = False,
|
@@ -2809,6 +2820,12 @@ class TransformerWrapper(Module):
|
|
2809
2820
|
else:
|
2810
2821
|
out = logits
|
2811
2822
|
|
2823
|
+
# logit entropies
|
2824
|
+
|
2825
|
+
if return_logit_entropies:
|
2826
|
+
intermediates.logit_entropies = calc_entropy(logits)
|
2827
|
+
return_intermediates = True
|
2828
|
+
|
2812
2829
|
# aux loss
|
2813
2830
|
|
2814
2831
|
if return_attn_z_loss:
|
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
8
8
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
9
|
-
x_transformers/x_transformers.py,sha256=
|
9
|
+
x_transformers/x_transformers.py,sha256=z3RH6jvjcxaAVfZoCS0HWrE0Gy55-eXOKtzRt7rRRIw,100811
|
10
10
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
11
11
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
12
|
-
x_transformers-1.43.
|
13
|
-
x_transformers-1.43.
|
14
|
-
x_transformers-1.43.
|
15
|
-
x_transformers-1.43.
|
16
|
-
x_transformers-1.43.
|
12
|
+
x_transformers-1.43.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
13
|
+
x_transformers-1.43.5.dist-info/METADATA,sha256=crd2xA3NbodKbOz9xY4D1j3XDbTmaY9vwkZZJOGoEw4,738
|
14
|
+
x_transformers-1.43.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
15
|
+
x_transformers-1.43.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
16
|
+
x_transformers-1.43.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|