tmnt 0.7.51b20240412__py3-none-any.whl → 0.7.52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/estimator.py +3 -3
- tmnt/modeling.py +22 -23
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/METADATA +1 -1
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/RECORD +8 -8
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/LICENSE +0 -0
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/NOTICE +0 -0
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/WHEEL +0 -0
- {tmnt-0.7.51b20240412.dist-info → tmnt-0.7.52.dist-info}/top_level.txt +0 -0
tmnt/estimator.py
CHANGED
@@ -18,7 +18,7 @@ import json
|
|
18
18
|
from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
|
19
19
|
from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
|
20
20
|
from tmnt.modeling import BowVAEModel, CovariateBowVAEModel, SeqBowVED
|
21
|
-
from tmnt.modeling import
|
21
|
+
from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
|
22
22
|
from tmnt.eval_npmi import EvaluateNPMI
|
23
23
|
from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution
|
24
24
|
|
@@ -1573,11 +1573,11 @@ class SeqBowEstimator(BaseEstimator):
|
|
1573
1573
|
class SeqBowMetricEstimator(SeqBowEstimator):
|
1574
1574
|
|
1575
1575
|
def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, use_teacher_forcing=False,
|
1576
|
-
|
1576
|
+
teacher_forcing_mode='rand',
|
1577
1577
|
use_sdml=False, non_scoring_index=-1, **kwargs):
|
1578
1578
|
super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
|
1579
1579
|
if use_teacher_forcing:
|
1580
|
-
self.loss_function =
|
1580
|
+
self.loss_function = CrossBatchCosineSimilarityLoss(teacher_mode = teacher_forcing_mode)
|
1581
1581
|
else:
|
1582
1582
|
self.loss_function = \
|
1583
1583
|
GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
|
tmnt/modeling.py
CHANGED
@@ -14,6 +14,7 @@ from tmnt.distribution import BaseDistribution
|
|
14
14
|
from torch import nn
|
15
15
|
from torch.nn.modules.loss import _Loss
|
16
16
|
import torch
|
17
|
+
from torch import Tensor
|
17
18
|
from torch.distributions.categorical import Categorical
|
18
19
|
|
19
20
|
from typing import List, Tuple, Dict, Optional, Union, NoReturn
|
@@ -775,41 +776,39 @@ class MultiNegativeCrossEntropyLoss(_Loss):
|
|
775
776
|
return self._loss(x1, l1, x2, l2)
|
776
777
|
|
777
778
|
|
778
|
-
class
|
779
|
+
class CrossBatchCosineSimilarityLoss(_Loss):
|
779
780
|
"""
|
780
781
|
Inputs:
|
781
782
|
- **x1**: Minibatch of data points with shape (batch_size, vector_dim)
|
782
783
|
- **x2**: Minibatch of data points with shape (batch_size, vector_dim)
|
783
784
|
Each item in x1 is a positive sample for the items with the same label in x2
|
784
|
-
That is, x1[0] and x2[0] form a positive pair iff label(x1[0]) = label(x2[0])
|
785
|
-
All data points in different rows should be decorrelated
|
786
785
|
|
787
786
|
Outputs:
|
788
787
|
- **loss**: loss tensor with shape (batch_size,).
|
789
788
|
"""
|
790
789
|
|
791
|
-
def __init__(self,
|
792
|
-
super(
|
793
|
-
self.
|
794
|
-
self.
|
795
|
-
|
790
|
+
def __init__(self, teacher_mode='rand', batch_axis=0, **kwargs):
|
791
|
+
super(CrossBatchCosineSimilarityLoss, self).__init__(batch_axis, **kwargs)
|
792
|
+
self.loss_fn = nn.MSELoss()
|
793
|
+
self.teacher_mode = teacher_mode
|
794
|
+
|
795
|
+
def cosine_sim(self, a: Tensor, b: Tensor) -> Tensor:
|
796
|
+
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
|
797
|
+
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
|
798
|
+
return torch.mm(a_norm, b_norm.transpose(0, 1))
|
796
799
|
|
797
800
|
def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
single_side_distances = single_side_distances / single_side_distances.sum(axis=1,keepdim=True).expand(batch_size, batch_size)
|
810
|
-
# multiply by the batch size to obtain the sum loss (kl_loss averages instead of sum)
|
811
|
-
return self.cross_entropy_loss(cross_side_distances, single_side_distances.to(single_side_distances.device))
|
812
|
-
|
801
|
+
scores = self.cosine_sim(x1,x2)
|
802
|
+
if self.teacher_mode == 'right':
|
803
|
+
labels = self.cosine_sim(x2,x2).detach()
|
804
|
+
elif self.teacher_mode == 'left':
|
805
|
+
labels = self.cosine_sim(x1,x1).detach()
|
806
|
+
else:
|
807
|
+
if np.random.randint(2):
|
808
|
+
labels = self.cosine_sim(x2,x2).detach()
|
809
|
+
else:
|
810
|
+
labels = self.cosine_sim(x1,x1).detach()
|
811
|
+
return self.loss_fn(scores, labels)
|
813
812
|
|
814
813
|
def forward(self, x1, l1, x2, l2):
|
815
814
|
return self._loss(x1, l1, x2, l2)
|
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
|
|
2
2
|
tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
|
3
3
|
tmnt/data_loading.py,sha256=A0tsM6x61BGhYBV6rAYdryz2NwbR__8EAYj_Q4Z-DCs,18736
|
4
4
|
tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
5
|
-
tmnt/estimator.py,sha256=
|
5
|
+
tmnt/estimator.py,sha256=MERanBwrbYqUcHC872qXCIjUoqjlTKnYjOCBu6mxo90,77217
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256=
|
8
|
+
tmnt/modeling.py,sha256=UJvwQU2ujmY3hUBmUuTWOsZ5AcUFcw-kQhYFK5pICTY,34549
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.52.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.52.dist-info/METADATA,sha256=8jzdkE7tv6P_5OAMS7_pp8_iPyAtwDVhmQ9o5Eo2Zfo,1443
|
22
|
+
tmnt-0.7.52.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.52.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
24
|
+
tmnt-0.7.52.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.52.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|