tmnt 0.7.50__py3-none-any.whl → 0.7.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/modeling.py +4 -1
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/METADATA +1 -1
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/RECORD +7 -7
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/LICENSE +0 -0
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/NOTICE +0 -0
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/WHEEL +0 -0
- {tmnt-0.7.50.dist-info → tmnt-0.7.51.dist-info}/top_level.txt +0 -0
tmnt/modeling.py
CHANGED
@@ -788,7 +788,7 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
788
788
|
- **loss**: loss tensor with shape (batch_size,).
|
789
789
|
"""
|
790
790
|
|
791
|
-
def __init__(self, teacher_right=True, metric_loss_temp=
|
791
|
+
def __init__(self, teacher_right=True, metric_loss_temp=0.5, batch_axis=0, **kwargs):
|
792
792
|
super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
|
793
793
|
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
794
794
|
self.metric_loss_temp = metric_loss_temp
|
@@ -799,11 +799,14 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
799
799
|
the function computes the kl divergence between the negative distances
|
800
800
|
and the smoothed label matrix.
|
801
801
|
"""
|
802
|
+
batch_size = l1.size()[0]
|
802
803
|
x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
|
803
804
|
x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
|
804
805
|
cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
|
805
806
|
single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
|
806
807
|
else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
|
808
|
+
# need to normalize these
|
809
|
+
single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
|
807
810
|
# multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
|
808
811
|
return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
|
809
812
|
|
@@ -5,7 +5,7 @@ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
|
5
5
|
tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256=
|
8
|
+
tmnt/modeling.py,sha256=wKDuUsw2bvsrvJ7LkcnSXAPh8cvUSd8y3Q7eGAf_JeU,35049
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.51.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.51.dist-info/METADATA,sha256=2KT4XrKVkPIKkkgeO_iEzHrG_ERSwX2Td1KTdyiKMn8,1443
|
22
|
+
tmnt-0.7.51.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.51.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
24
|
+
tmnt-0.7.51.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.51.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|