x-transformers 1.43.4__py3-none-any.whl → 1.43.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -38,6 +38,7 @@ class LayerIntermediates:
38
38
  attn_z_loss: Tensor | None = None
39
39
  mems: Tensor | None = None
40
40
  memory_tokens: Tensor | None = None
41
+ logit_entropies: Tensor | None = None
41
42
 
42
43
  LinearNoBias = partial(nn.Linear, bias = False)
43
44
 
@@ -136,6 +137,15 @@ def or_reduce(masks):
136
137
  head = head | rest
137
138
  return head
138
139
 
140
+ # entropy
141
+
142
+ def calc_entropy(
143
+ t: Tensor,
144
+ is_prob = False
145
+ ):
146
+ prob = t.softmax(dim = -1) if not is_prob else t
147
+ return -(prob * log(prob)).sum(dim = -1)
148
+
139
149
  # auxiliary loss helpers
140
150
 
141
151
  def calc_z_loss(
@@ -2592,6 +2602,7 @@ class TransformerWrapper(Module):
2592
2602
  return_embeddings = False,
2593
2603
  return_logits_and_embeddings = False,
2594
2604
  return_intermediates = False,
2605
+ return_logit_entropies = False,
2595
2606
  mask = None,
2596
2607
  return_mems = False,
2597
2608
  return_attn = False,
@@ -2809,6 +2820,12 @@ class TransformerWrapper(Module):
2809
2820
  else:
2810
2821
  out = logits
2811
2822
 
2823
+ # logit entropies
2824
+
2825
+ if return_logit_entropies:
2826
+ intermediates.logit_entropies = calc_entropy(logits)
2827
+ return_intermediates = True
2828
+
2812
2829
  # aux loss
2813
2830
 
2814
2831
  if return_attn_z_loss:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.43.4
3
+ Version: 1.43.5
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=gn0vRtwbjBA67T-Z8tkU-k3Xte0PaMTxZlmzdK8UsFw,100392
9
+ x_transformers/x_transformers.py,sha256=z3RH6jvjcxaAVfZoCS0HWrE0Gy55-eXOKtzRt7rRRIw,100811
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.43.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.43.4.dist-info/METADATA,sha256=Nlj9DcMqnMxJH-xR4Dwd4aU1U-UQIUshpQaMDcggVes,738
14
- x_transformers-1.43.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
- x_transformers-1.43.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.43.4.dist-info/RECORD,,
12
+ x_transformers-1.43.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.43.5.dist-info/METADATA,sha256=crd2xA3NbodKbOz9xY4D1j3XDbTmaY9vwkZZJOGoEw4,738
14
+ x_transformers-1.43.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
+ x_transformers-1.43.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.43.5.dist-info/RECORD,,