x-transformers 1.43.2__py3-none-any.whl → 1.43.5__py3-none-any.whl

x_transformers/x_transformers.py

@@ -38,6 +38,7 @@ class LayerIntermediates:
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
+    logit_entropies: Tensor | None = None

 LinearNoBias = partial(nn.Linear, bias = False)

@@ -136,6 +137,15 @@ def or_reduce(masks):
     head = head | rest
     return head

+# entropy
+
+def calc_entropy(
+    t: Tensor,
+    is_prob = False
+):
+    prob = t.softmax(dim = -1) if not is_prob else t
+    return -(prob * log(prob)).sum(dim = -1)
+
 # auxiliary loss helpers

 def calc_z_loss(
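For reference, the new calc_entropy helper computes the Shannon entropy of the distribution along the last dimension. Below is a standalone sketch of the same logic; the library's own log helper is assumed here to be a clamped torch.log, and the uniform-logits check is purely illustrative.

import math
import torch

# assumed stand-in for the library's numerically safe log helper
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def calc_entropy(t, is_prob = False):
    # softmax the logits unless a probability distribution was passed in
    prob = t.softmax(dim = -1) if not is_prob else t
    return -(prob * log(prob)).sum(dim = -1)

logits = torch.zeros(2, 4, 10)     # uniform logits over 10 classes
entropy = calc_entropy(logits)     # per-position entropy, shape (2, 4)
print(torch.allclose(entropy, torch.full_like(entropy, math.log(10))))  # True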
@@ -2270,16 +2280,16 @@ class AttentionLayers(Module):
         if self.need_condition:
             final_norm = maybe(partial)(final_norm, **norm_kwargs)

-        if self.resi_dual:
-            x = x + final_norm(outer_residual)
-        else:
-            x = final_norm(x)
-
         # take care of multistreams if needed, use sum for now

         if is_multistream:
             x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

+        if self.resi_dual:
+            x = x + final_norm(outer_residual)
+        else:
+            x = final_norm(x)
+
         if not return_hiddens:
             return x

@@ -2592,6 +2602,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_logit_entropies = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -2809,6 +2820,12 @@ class TransformerWrapper(Module):
         else:
             out = logits

+        # logit entropies
+
+        if return_logit_entropies:
+            intermediates.logit_entropies = calc_entropy(logits)
+            return_intermediates = True
+
         # aux loss

         if return_attn_z_loss:
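Taken together, the changes above add a return_logit_entropies flag to TransformerWrapper.forward: when set, calc_entropy(logits) is stored on the intermediates and return_intermediates is forced on. A minimal usage sketch follows, assuming the usual TransformerWrapper / Decoder construction from the project README; the hyperparameters are illustrative only.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,          # illustrative vocab size
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# requesting entropies forces return_intermediates = True,
# so the call returns (logits, intermediates)
logits, intermediates = model(x, return_logit_entropies = True)

print(intermediates.logit_entropies.shape)   # per-token entropy, e.g. (1, 1024)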
{x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.2
+Version: 1.43.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.43.2.dist-info → x_transformers-1.43.5.dist-info}/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=SsmQacZJ95oTBCcu33m42-_xhBlnoO4G9MkGRQNF_rI,100392
+x_transformers/x_transformers.py,sha256=z3RH6jvjcxaAVfZoCS0HWrE0Gy55-eXOKtzRt7rRRIw,100811
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.43.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.43.2.dist-info/METADATA,sha256=ilN5k-uCKnN1LFgV_mTKt6zX2f1NxrUF7-nqwR_yJ4c,738
-x_transformers-1.43.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.43.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.43.2.dist-info/RECORD,,
+x_transformers-1.43.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.43.5.dist-info/METADATA,sha256=crd2xA3NbodKbOz9xY4D1j3XDbTmaY9vwkZZJOGoEw4,738
+x_transformers-1.43.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.43.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.43.5.dist-info/RECORD,,