x-transformers 1.43.4__tar.gz → 1.43.5__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {x_transformers-1.43.4/x_transformers.egg-info → x_transformers-1.43.5}/PKG-INFO +1 -1
- {x_transformers-1.43.4 → x_transformers-1.43.5}/setup.py +1 -1
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/x_transformers.py +17 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.43.4 → x_transformers-1.43.5}/LICENSE +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/README.md +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/setup.cfg +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/tests/test_x_transformers.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/__init__.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/attend.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/continuous.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/dpo.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers/xval.py +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers.egg-info/top_level.txt +0 -0
@@ -38,6 +38,7 @@ class LayerIntermediates:
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
+    logit_entropies: Tensor | None = None

 LinearNoBias = partial(nn.Linear, bias = False)

@@ -136,6 +137,15 @@ def or_reduce(masks):
         head = head | rest
     return head

+# entropy
+
+def calc_entropy(
+    t: Tensor,
+    is_prob = False
+):
+    prob = t.softmax(dim = -1) if not is_prob else t
+    return -(prob * log(prob)).sum(dim = -1)
+
 # auxiliary loss helpers

 def calc_z_loss(
@@ -2592,6 +2602,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_logit_entropies = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -2809,6 +2820,12 @@ class TransformerWrapper(Module):
         else:
             out = logits

+        # logit entropies
+
+        if return_logit_entropies:
+            intermediates.logit_entropies = calc_entropy(logits)
+            return_intermediates = True
+
         # aux loss

         if return_attn_z_loss:
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.43.4 → x_transformers-1.43.5}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|