tmnt 0.7.51b20240412__py3-none-any.whl → 0.7.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/estimator.py CHANGED
@@ -18,7 +18,7 @@ import json
18
18
  from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
19
19
  from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
20
20
  from tmnt.modeling import BowVAEModel, CovariateBowVAEModel, SeqBowVED
21
- from tmnt.modeling import SelfEmbeddingCrossEntropyLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
21
+ from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
22
22
  from tmnt.eval_npmi import EvaluateNPMI
23
23
  from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution
24
24
 
@@ -1573,11 +1573,11 @@ class SeqBowEstimator(BaseEstimator):
1573
1573
  class SeqBowMetricEstimator(SeqBowEstimator):
1574
1574
 
1575
1575
  def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, use_teacher_forcing=False,
1576
- teacher_forcing_right=True,
1576
+ teacher_forcing_mode='rand',
1577
1577
  use_sdml=False, non_scoring_index=-1, **kwargs):
1578
1578
  super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
1579
1579
  if use_teacher_forcing:
1580
- self.loss_function = SelfEmbeddingCrossEntropyLoss(teacher_right=teacher_forcing_right, metric_loss_temp=metric_loss_temp)
1580
+ self.loss_function = CrossBatchCosineSimilarityLoss(teacher_mode = teacher_forcing_mode)
1581
1581
  else:
1582
1582
  self.loss_function = \
1583
1583
  GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
tmnt/modeling.py CHANGED
@@ -14,6 +14,7 @@ from tmnt.distribution import BaseDistribution
14
14
  from torch import nn
15
15
  from torch.nn.modules.loss import _Loss
16
16
  import torch
17
+ from torch import Tensor
17
18
  from torch.distributions.categorical import Categorical
18
19
 
19
20
  from typing import List, Tuple, Dict, Optional, Union, NoReturn
@@ -775,41 +776,39 @@ class MultiNegativeCrossEntropyLoss(_Loss):
775
776
  return self._loss(x1, l1, x2, l2)
776
777
 
777
778
 
778
- class SelfEmbeddingCrossEntropyLoss(_Loss):
779
+ class CrossBatchCosineSimilarityLoss(_Loss):
779
780
  """
780
781
  Inputs:
781
782
  - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
782
783
  - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
783
784
  Each item in x1 is a positive sample for the items with the same label in x2
784
- That is, x1[0] and x2[0] form a positive pair iff label(x1[0]) = label(x2[0])
785
- All data points in different rows should be decorrelated
786
785
 
787
786
  Outputs:
788
787
  - **loss**: loss tensor with shape (batch_size,).
789
788
  """
790
789
 
791
- def __init__(self, teacher_right=True, metric_loss_temp=0.5, batch_axis=0, **kwargs):
792
- super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
793
- self.cross_entropy_loss = nn.CrossEntropyLoss()
794
- self.metric_loss_temp = metric_loss_temp
795
- self.teacher_right = teacher_right
790
+ def __init__(self, teacher_mode='rand', batch_axis=0, **kwargs):
791
+ super(CrossBatchCosineSimilarityLoss, self).__init__(batch_axis, **kwargs)
792
+ self.loss_fn = nn.MSELoss()
793
+ self.teacher_mode = teacher_mode
794
+
795
+ def cosine_sim(self, a: Tensor, b: Tensor) -> Tensor:
796
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
797
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
798
+ return torch.mm(a_norm, b_norm.transpose(0, 1))
796
799
 
797
800
  def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
798
- """
799
- the function computes the kl divergence between the negative distances
800
- and the smoothed label matrix.
801
- """
802
- batch_size = l1.size()[0]
803
- x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
804
- x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
805
- cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp
806
- single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) / self.metric_loss_temp if self.teacher_right \
807
- else torch.mm(x1_norm, x1_norm.transpose(0,1)) / self.metric_loss_temp
808
- # need to normalize these
809
- single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
810
- # multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
811
- return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
812
-
801
+ scores = self.cosine_sim(x1,x2)
802
+ if self.teacher_mode == 'right':
803
+ labels = self.cosine_sim(x2,x2).detach()
804
+ elif self.teacher_mode == 'left':
805
+ labels = self.cosine_sim(x1,x1).detach()
806
+ else:
807
+ if np.random.randint(2):
808
+ labels = self.cosine_sim(x2,x2).detach()
809
+ else:
810
+ labels = self.cosine_sim(x1,x1).detach()
811
+ return self.loss_fn(scores, labels)
813
812
 
814
813
  def forward(self, x1, l1, x2, l2):
815
814
  return self._loss(x1, l1, x2, l2)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.51b20240412
3
+ Version: 0.7.52
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
3
  tmnt/data_loading.py,sha256=A0tsM6x61BGhYBV6rAYdryz2NwbR__8EAYj_Q4Z-DCs,18736
4
4
  tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
- tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
5
+ tmnt/estimator.py,sha256=MERanBwrbYqUcHC872qXCIjUoqjlTKnYjOCBu6mxo90,77217
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=wKDuUsw2bvsrvJ7LkcnSXAPh8cvUSd8y3Q7eGAf_JeU,35049
8
+ tmnt/modeling.py,sha256=UJvwQU2ujmY3hUBmUuTWOsZ5AcUFcw-kQhYFK5pICTY,34549
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.51b20240412.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.51b20240412.dist-info/METADATA,sha256=x2c3Q8FLiFfUfbE68ih_lMP7u0_i5M5RHBLisRbXRVw,1452
22
- tmnt-0.7.51b20240412.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.51b20240412.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
- tmnt-0.7.51b20240412.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.51b20240412.dist-info/RECORD,,
20
+ tmnt-0.7.52.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.52.dist-info/METADATA,sha256=8jzdkE7tv6P_5OAMS7_pp8_iPyAtwDVhmQ9o5Eo2Zfo,1443
22
+ tmnt-0.7.52.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.52.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
+ tmnt-0.7.52.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.52.dist-info/RECORD,,