tmnt 0.7.47b20240410__py3-none-any.whl → 0.7.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/estimator.py CHANGED
@@ -18,7 +18,7 @@ import json
 from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
 from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
 from tmnt.modeling import BowVAEModel, CovariateBowVAEModel, SeqBowVED
-from tmnt.modeling import GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
+from tmnt.modeling import SelfEmbeddingCrossEntropyLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
 from tmnt.eval_npmi import EvaluateNPMI
 from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution
 
@@ -1572,12 +1572,16 @@ class SeqBowEstimator(BaseEstimator):
 
 class SeqBowMetricEstimator(SeqBowEstimator):
 
-    def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1,
+    def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, use_teacher_forcing=False,
+                 teacher_forcing_right=True,
                  use_sdml=False, non_scoring_index=-1, **kwargs):
         super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
-        self.loss_function = \
-            GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
-            else MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
+        if use_teacher_forcing:
+            self.loss_function = SelfEmbeddingCrossEntropyLoss(teacher_right=teacher_forcing_right, metric_loss_temp=metric_loss_temp)
+        else:
+            self.loss_function = \
+                GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
+                else MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
         self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation
 
 
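The net effect of the new use_teacher_forcing and teacher_forcing_right arguments is simply to pick a different loss object at construction time. As a reading aid, the sketch below restates that selection as a standalone helper; select_metric_loss is an illustrative name (not part of tmnt), while the class names, keyword arguments, and defaults come directly from the hunk above.

# A minimal sketch of the loss selection introduced in 0.7.49; `select_metric_loss`
# is a hypothetical helper used here only to illustrate the branching.
from tmnt.modeling import (GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss,
                           SelfEmbeddingCrossEntropyLoss)

def select_metric_loss(use_teacher_forcing=False, teacher_forcing_right=True,
                       use_sdml=False, sdml_smoothing_factor=0.3,
                       metric_loss_temp=0.1, non_scoring_index=-1):
    # teacher forcing takes priority over the SDML / multi-negative options
    if use_teacher_forcing:
        return SelfEmbeddingCrossEntropyLoss(teacher_right=teacher_forcing_right,
                                             metric_loss_temp=metric_loss_temp)
    if use_sdml:
        return GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor,
                                   x2_downweight_idx=non_scoring_index)
    return MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor,
                                         metric_loss_temp=metric_loss_temp)

# Default behaviour is unchanged from 0.7.47b20240410:
loss_fn = select_metric_loss()                            # MultiNegativeCrossEntropyLoss
# The new flag switches to the self-embedding loss:
tf_loss = select_metric_loss(use_teacher_forcing=True)    # SelfEmbeddingCrossEntropyLoss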
tmnt/modeling.py CHANGED
@@ -775,3 +775,37 @@ class MultiNegativeCrossEntropyLoss(_Loss):
         return self._loss(x1, l1, x2, l2)
 
 
+class SelfEmbeddingCrossEntropyLoss(_Loss):
+    """
+    Inputs:
+        - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
+        - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
+    Each item in x1 is a positive sample for the items with the same label in x2
+    That is, x1[0] and x2[0] form a positive pair iff label(x1[0]) = label(x2[0])
+    All data points in different rows should be decorrelated
+
+    Outputs:
+        - **loss**: loss tensor with shape (batch_size,).
+    """
+
+    def __init__(self, teacher_right=True, metric_loss_temp=0.1, batch_axis=0, **kwargs):
+        super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
+        self.cross_entropy_loss = nn.CrossEntropyLoss()
+        self.metric_loss_temp = metric_loss_temp
+        self.teacher_right = teacher_right
+
+    def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
+        """
+        the function computes the kl divergence between the negative distances
+        and the smoothed label matrix.
+        """
+        x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
+        x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
+        cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0,1))
+        single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0,1)) if self.teacher_right else torch.mm(x1_norm, x1_norm.transpose(0,1))
+        # multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
+        return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
+
+
+    def forward(self, x1, l1, x2, l2):
+        return self._loss(x1, l1, x2, l2)
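In practice the new loss L2-normalizes both minibatches and compares two cosine-similarity matrices with nn.CrossEntropyLoss: the cross-side matrix (x1 against x2) as the prediction, and a same-side "teacher" matrix (x2 against x2 when teacher_right=True, otherwise x1 against x1) as the target. The sketch below is a minimal usage example, not part of the package: the batch size and dimensionality are arbitrary, and it assumes a PyTorch version (1.10 or later) whose CrossEntropyLoss accepts dense, probability-style targets, which is what the class relies on.

# A minimal, illustrative usage sketch of the new SelfEmbeddingCrossEntropyLoss
# (hypothetical batch size / dimensionality; the label tensors are accepted by
# forward() but not used by this particular loss).
import torch
from tmnt.modeling import SelfEmbeddingCrossEntropyLoss

batch_size, dim = 4, 16
x1 = torch.randn(batch_size, dim)       # query-side embeddings
x2 = torch.randn(batch_size, dim)       # used as the teacher side when teacher_right=True
labels = torch.arange(batch_size)       # passed through, ignored by _loss()

loss_fn = SelfEmbeddingCrossEntropyLoss(teacher_right=True, metric_loss_temp=0.1)
loss = loss_fn(x1, labels, x2, labels)  # cross entropy between x1 @ x2.T and x2 @ x2.T (after normalization)
print(loss)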
tmnt-0.7.47b20240410.dist-info/METADATA → tmnt-0.7.49.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tmnt
-Version: 0.7.47b20240410
+Version: 0.7.49
 Summary: Topic modeling neural toolkit
 Home-page: https://github.com/mitre/tmnt.git
 Author: The MITRE Corporation
tmnt-0.7.47b20240410.dist-info/RECORD → tmnt-0.7.49.dist-info/RECORD RENAMED
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
 tmnt/data_loading.py,sha256=A0tsM6x61BGhYBV6rAYdryz2NwbR__8EAYj_Q4Z-DCs,18736
 tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
-tmnt/estimator.py,sha256=cRdA3s3_PmbSU36xYc8cfano_rkqEl9j_0FM3eZ8IA8,76953
+tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
 tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
 tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
-tmnt/modeling.py,sha256=372eAVcnI5xcBYRwSO8N0XK_ECWHwRw7KfuIB8uz3RA,33018
+tmnt/modeling.py,sha256=IZLc9SMaqKtUlEaDZXzy9g6ZdJW1GItyFzKPvvxlxzg,34761
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
 tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
 tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
-tmnt-0.7.47b20240410.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
-tmnt-0.7.47b20240410.dist-info/METADATA,sha256=VXzhwjgWkC12v6gaiuMYl2pGokmyY8GhhOuPYs5tQog,1452
-tmnt-0.7.47b20240410.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
-tmnt-0.7.47b20240410.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-tmnt-0.7.47b20240410.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
-tmnt-0.7.47b20240410.dist-info/RECORD,,
+tmnt-0.7.49.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.49.dist-info/METADATA,sha256=6M-Y1ETuXDPGG_ZLNTMVJXqvicWRfJmb6_5GpmicZoo,1443
+tmnt-0.7.49.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.49.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+tmnt-0.7.49.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.49.dist-info/RECORD,,