tmnt 0.7.49__py3-none-any.whl → 0.7.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/modeling.py +4 -3
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/METADATA +1 -1
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/RECORD +7 -7
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/LICENSE +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/NOTICE +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/WHEEL +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.50.dist-info}/top_level.txt +0 -0
tmnt/modeling.py
CHANGED
@@ -788,7 +788,7 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
788
788
|
- **loss**: loss tensor with shape (batch_size,).
|
789
789
|
"""
|
790
790
|
|
791
|
-
def __init__(self, teacher_right=True, metric_loss_temp=0
|
791
|
+
def __init__(self, teacher_right=True, metric_loss_temp=1.0, batch_axis=0, **kwargs):
|
792
792
|
super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
|
793
793
|
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
794
794
|
self.metric_loss_temp = metric_loss_temp
|
@@ -801,8 +801,9 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
801
801
|
"""
|
802
802
|
x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
|
803
803
|
x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
|
804
|
-
cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1))
|
805
|
-
single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1))
|
804
|
+
cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
|
805
|
+
single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
|
806
|
+
else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
|
806
807
|
# multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
|
807
808
|
return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
|
808
809
|
|
@@ -5,7 +5,7 @@ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
|
5
5
|
tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256=
|
8
|
+
tmnt/modeling.py,sha256=9aVXSDfiNCM6fzRkvdrvmT1Vv69k_XJ_t12b9I-T_qA,34847
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.50.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.50.dist-info/METADATA,sha256=kVFUVgKEDDFi7uUxamXxAcsKUD4QjPvsUJGtLiCrGAM,1443
|
22
|
+
tmnt-0.7.50.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.50.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
24
|
+
tmnt-0.7.50.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.50.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|