tmnt-0.7.47b20240410-py3-none-any.whl → tmnt-0.7.49-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/estimator.py +9 -5
- tmnt/modeling.py +34 -0
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/METADATA +1 -1
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/RECORD +8 -8
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/LICENSE +0 -0
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/NOTICE +0 -0
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/WHEEL +0 -0
- {tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/top_level.txt +0 -0
tmnt/estimator.py
CHANGED
@@ -18,7 +18,7 @@ import json
 from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
 from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
 from tmnt.modeling import BowVAEModel, CovariateBowVAEModel, SeqBowVED
-from tmnt.modeling import GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
+from tmnt.modeling import SelfEmbeddingCrossEntropyLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
 from tmnt.eval_npmi import EvaluateNPMI
 from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution

@@ -1572,12 +1572,16 @@ class SeqBowEstimator(BaseEstimator):

 class SeqBowMetricEstimator(SeqBowEstimator):

-    def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1,
+    def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, use_teacher_forcing=False,
+                 teacher_forcing_right=True,
                  use_sdml=False, non_scoring_index=-1, **kwargs):
         super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
-
-
-
+        if use_teacher_forcing:
+            self.loss_function = SelfEmbeddingCrossEntropyLoss(teacher_right=teacher_forcing_right, metric_loss_temp=metric_loss_temp)
+        else:
+            self.loss_function = \
+                GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
+                else MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
         self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation

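For quick reference, the sketch below pulls the new loss-selection branch out of the estimator so it can be read in isolation. The helper name select_metric_loss is hypothetical and not part of TMNT; the loss classes and keyword arguments are exactly the ones visible in the hunk above.

# Sketch only: mirrors the loss selection added to SeqBowMetricEstimator.__init__ in 0.7.49.
# `select_metric_loss` is a hypothetical helper; only the loss constructors and their
# keyword arguments come from the diff.
from tmnt.modeling import (GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss,
                           SelfEmbeddingCrossEntropyLoss)

def select_metric_loss(use_teacher_forcing=False, teacher_forcing_right=True,
                       use_sdml=False, sdml_smoothing_factor=0.3,
                       metric_loss_temp=0.1, non_scoring_index=-1):
    if use_teacher_forcing:
        # New in 0.7.49: target the same-side ("teacher") similarity matrix directly.
        return SelfEmbeddingCrossEntropyLoss(teacher_right=teacher_forcing_right,
                                             metric_loss_temp=metric_loss_temp)
    if use_sdml:
        return GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor,
                                   x2_downweight_idx=non_scoring_index)
    return MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor,
                                         metric_loss_temp=metric_loss_temp)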
tmnt/modeling.py
CHANGED
@@ -775,3 +775,37 @@ class MultiNegativeCrossEntropyLoss(_Loss):
         return self._loss(x1, l1, x2, l2)


+class SelfEmbeddingCrossEntropyLoss(_Loss):
+    """
+    Inputs:
+        - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
+        - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
+          Each item in x1 is a positive sample for the items with the same label in x2;
+          that is, x1[i] and x2[j] form a positive pair iff label(x1[i]) = label(x2[j]).
+          Data points in different rows should be decorrelated.
+
+    Outputs:
+        - **loss**: scalar loss tensor (mean-reduced cross entropy).
+    """
+
+    def __init__(self, teacher_right=True, metric_loss_temp=0.1, batch_axis=0, **kwargs):
+        super(SelfEmbeddingCrossEntropyLoss, self).__init__(batch_axis, **kwargs)
+        self.cross_entropy_loss = nn.CrossEntropyLoss()
+        self.metric_loss_temp = metric_loss_temp
+        self.teacher_right = teacher_right
+
+    def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
+        """
+        Computes a cross-entropy loss between the cross-side cosine similarities (x1 vs. x2)
+        and the same-side cosine similarities, which act as soft targets.
+        """
+        x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
+        x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
+        cross_side_distances = torch.mm(x1_norm, x2_norm.transpose(0, 1))
+        single_side_distances = torch.mm(x2_norm, x2_norm.transpose(0, 1)) if self.teacher_right else torch.mm(x1_norm, x1_norm.transpose(0, 1))
+        # teacher_right=True uses x2's self-similarity matrix as the target; otherwise x1's is used
+        return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
+
+
+    def forward(self, x1, l1, x2, l2):
+        return self._loss(x1, l1, x2, l2)
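The new loss can be exercised on its own. The following is a minimal sketch, not taken from the package, assuming an installed tmnt 0.7.49 and PyTorch 1.10 or newer (nn.CrossEntropyLoss with a soft, matrix-valued target); the batch size and embedding dimension are arbitrary.

# Illustrative sketch, not from the diff: exercising SelfEmbeddingCrossEntropyLoss
# on random embeddings. Labels are accepted by the signature but unused by _loss.
import torch
from tmnt.modeling import SelfEmbeddingCrossEntropyLoss

batch_size, dim = 8, 384
x1 = torch.randn(batch_size, dim)      # e.g. embeddings of one view of each document
x2 = torch.randn(batch_size, dim)      # e.g. embeddings of the paired ("teacher") view
l1 = l2 = torch.arange(batch_size)     # one label per row

loss_fn = SelfEmbeddingCrossEntropyLoss(teacher_right=True, metric_loss_temp=0.1)
# Cross entropy of the x1-vs-x2 similarity matrix against the x2-vs-x2 similarity matrix
loss = loss_fn(x1, l1, x2, l2)
print(loss.item())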
{tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/RECORD
CHANGED
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
 tmnt/data_loading.py,sha256=A0tsM6x61BGhYBV6rAYdryz2NwbR__8EAYj_Q4Z-DCs,18736
 tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
-tmnt/estimator.py,sha256=
+tmnt/estimator.py,sha256=i37NVmUseDuEWfk4cwZcShsRrbINLbtrqRzDAPmJUwU,77249
 tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
 tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
-tmnt/modeling.py,sha256=
+tmnt/modeling.py,sha256=IZLc9SMaqKtUlEaDZXzy9g6ZdJW1GItyFzKPvvxlxzg,34761
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
 tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
 tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
+tmnt-0.7.49.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.49.dist-info/METADATA,sha256=6M-Y1ETuXDPGG_ZLNTMVJXqvicWRfJmb6_5GpmicZoo,1443
+tmnt-0.7.49.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.49.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+tmnt-0.7.49.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.49.dist-info/RECORD,,
{tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/LICENSE
File without changes

{tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/NOTICE
File without changes

{tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/WHEEL
File without changes

{tmnt-0.7.47b20240410.dist-info → tmnt-0.7.49.dist-info}/top_level.txt
File without changes