tmnt 0.7.49__py3-none-any.whl → 0.7.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/modeling.py +7 -3
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/METADATA +1 -1
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/RECORD +7 -7
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/LICENSE +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/NOTICE +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/WHEEL +0 -0
- {tmnt-0.7.49.dist-info → tmnt-0.7.51.dist-info}/top_level.txt +0 -0
tmnt/modeling.py
CHANGED
@@ -788,7 +788,7 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
788
788
|
- **loss**: loss tensor with shape (batch_size,).
|
789
789
|
"""
|
790
790
|
|
791
|
-
def __init__(self, teacher_right=True, metric_loss_temp=0.
|
791
|
+
def __init__(self, teacher_right=True, metric_loss_temp=0.5, batch_axis=0, **kwargs):
|
792
792
|
super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
|
793
793
|
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
794
794
|
self.metric_loss_temp = metric_loss_temp
|
@@ -799,10 +799,14 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
|
|
799
799
|
the function computes the kl divergence between the negative distances
|
800
800
|
and the smoothed label matrix.
|
801
801
|
"""
|
802
|
+
batch_size = l1.size()[0]
|
802
803
|
x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
|
803
804
|
x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
|
804
|
-
cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1))
|
805
|
-
single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1))
|
805
|
+
cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
|
806
|
+
single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
|
807
|
+
else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
|
808
|
+
# need to normalize these
|
809
|
+
single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
|
806
810
|
# multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
|
807
811
|
return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
|
808
812
|
|
@@ -5,7 +5,7 @@ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
|
5
5
|
tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256=
|
8
|
+
tmnt/modeling.py,sha256=wKDuUsw2bvsrvJ7LkcnSXAPh8cvUSd8y3Q7eGAf_JeU,35049
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.51.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.51.dist-info/METADATA,sha256=2KT4XrKVkPIKkkgeO_iEzHrG_ERSwX2Td1KTdyiKMn8,1443
|
22
|
+
tmnt-0.7.51.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.51.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
24
|
+
tmnt-0.7.51.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.51.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|