tmnt 0.7.49__py3-none-any.whl → 0.7.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/modeling.py CHANGED
@@ -788,7 +788,7 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
788
788
  - **loss**: loss tensor with shape (batch_size,).
789
789
  """
790
790
 
791
- def __init__(self, teacher_right=True, metric_loss_temp=0.1, batch_axis=0, **kwargs):
791
+ def __init__(self, teacher_right=True, metric_loss_temp=0.5, batch_axis=0, **kwargs):
792
792
  super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
793
793
  self.cross_entropy_loss = nn.CrossEntropyLoss()
794
794
  self.metric_loss_temp = metric_loss_temp
@@ -799,10 +799,14 @@ class SelfEmbeddingCrossEntropyLoss(_Loss):
799
799
  the function computes the kl divergence between the negative distances
800
800
  and the smoothed label matrix.
801
801
  """
802
+ batch_size = l1.size()[0]
802
803
  x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
803
804
  x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
804
- cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1))
805
- single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) if self.teacher_right else torch.mm(x1_norm, x1_norm.transpose(0,1))
805
+ cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
806
+ single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
807
+ else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
808
+ # need to normalize these
809
+ single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
806
810
  # multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
807
811
  return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
808
812
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.49
3
+ Version: 0.7.51
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -5,7 +5,7 @@ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
5
  tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=IZLc9SMaqKtUlEaDZXzy9g6ZdJW1GItyFzKPvvxlxzg,34761
8
+ tmnt/modeling.py,sha256=wKDuUsw2bvsrvJ7LkcnSXAPh8cvUSd8y3Q7eGAf_JeU,35049
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.49.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.49.dist-info/METADATA,sha256=6M-Y1ETuXDPGG_ZLNTMVJXqvicWRfJmb6_5GpmicZoo,1443
22
- tmnt-0.7.49.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.49.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
- tmnt-0.7.49.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.49.dist-info/RECORD,,
20
+ tmnt-0.7.51.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.51.dist-info/METADATA,sha256=2KT4XrKVkPIKkkgeO_iEzHrG_ERSwX2Td1KTdyiKMn8,1443
22
+ tmnt-0.7.51.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.51.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
+ tmnt-0.7.51.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.51.dist-info/RECORD,,
File without changes
File without changes
File without changes