tmnt 0.7.50__py3-none-any.whl → 0.7.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/modeling.py CHANGED
@@ -788,7 +788,7 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
788
788
  - **loss**: loss tensor with shape (batch_size,).
789
789
  """
790
790
 
791
- def __init__(self, teacher_right=True, metric_loss_temp=1.0, batch_axis=0, **kwargs):
791
+ def __init__(self, teacher_right=True, metric_loss_temp=0.5, batch_axis=0, **kwargs):
792
792
  super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
793
793
  self.cross_entropy_loss = nn.CrossEntropyLoss()
794
794
  self.metric_loss_temp = metric_loss_temp
@@ -799,11 +799,14 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
799
799
  the function computes the kl divergence between the negative distances
800
800
  and the smoothed label matrix.
801
801
  """
802
+ batch_size = l1.size()[0]
802
803
  x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
803
804
  x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
804
805
  cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
805
806
  single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
806
807
  else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
808
+ # need to normalize these
809
+ single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
807
810
  # multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
808
811
  return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
809
812
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.50
3
+ Version: 0.7.51
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -5,7 +5,7 @@ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
5
  tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=9aVXSDfiNCM6fzRkvdrvmT1Vv69k_XJ_t12b9I-T_qA,34847
8
+ tmnt/modeling.py,sha256=wKDuUsw2bvsrvJ7LkcnSXAPh8cvUSd8y3Q7eGAf_JeU,35049
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.50.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.50.dist-info/METADATA,sha256=kVFUVgKEDDFi7uUxamXxAcsKUD4QjPvsUJGtLiCrGAM,1443
22
- tmnt-0.7.50.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.50.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
- tmnt-0.7.50.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.50.dist-info/RECORD,,
20
+ tmnt-0.7.51.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.51.dist-info/METADATA,sha256=2KT4XrKVkPIKkkgeO_iEzHrG_ERSwX2Td1KTdyiKMn8,1443
22
+ tmnt-0.7.51.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.51.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
+ tmnt-0.7.51.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.51.dist-info/RECORD,,
File without changes
File without changes
File without changes