tmnt 0.7.44b20240123__py3-none-any.whl → 0.7.44b20240125__py3-none-any.whl
This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- tmnt/data_loading.py +2 -1
- tmnt/estimator.py +11 -4
- tmnt/modeling.py +4 -2
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/METADATA +1 -1
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/RECORD +9 -9
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/LICENSE +0 -0
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/NOTICE +0 -0
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/WHEEL +0 -0
- {tmnt-0.7.44b20240123.dist-info → tmnt-0.7.44b20240125.dist-info}/top_level.txt +0 -0
tmnt/data_loading.py
CHANGED
@@ -38,7 +38,8 @@ llm_catalog = {
|
|
38
38
|
'bert-base-uncased' : (AutoTokenizer.from_pretrained, BertModel.from_pretrained),
|
39
39
|
'openai-gpt' : (AutoTokenizer.from_pretrained, OpenAIGPTModel.from_pretrained),
|
40
40
|
'sentence-transformers/all-mpnet-base-v2' : (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
|
41
|
-
'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
|
41
|
+
'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
|
42
|
+
'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
|
42
43
|
## add more model options here if desired
|
43
44
|
}
|
44
45
|
|
tmnt/estimator.py
CHANGED
@@ -1116,6 +1116,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1116
1116
|
pure_classifier_objective = False,
|
1117
1117
|
validate_each_epoch = False,
|
1118
1118
|
entropy_loss_coef = 1000.0,
|
1119
|
+
pool_encoder = True,
|
1119
1120
|
**kwargs):
|
1120
1121
|
super(SeqBowEstimator, self).__init__(*args, **kwargs)
|
1121
1122
|
self.pure_classifier_objective = pure_classifier_objective
|
@@ -1135,6 +1136,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1135
1136
|
self.decoder_lr = decoder_lr
|
1136
1137
|
self._bow_matrix = None
|
1137
1138
|
self.entropy_loss_coef = entropy_loss_coef
|
1139
|
+
self.pool_encoder = pool_encoder
|
1138
1140
|
|
1139
1141
|
|
1140
1142
|
@classmethod
|
@@ -1217,7 +1219,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1217
1219
|
def _get_model(self):
|
1218
1220
|
llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
|
1219
1221
|
model = SeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
|
1220
|
-
vocab_size = len(self.vocabulary), use_pooling =
|
1222
|
+
vocab_size = len(self.vocabulary), use_pooling = self.pool_encoder,
|
1221
1223
|
entropy_loss_coef=self.entropy_loss_coef,
|
1222
1224
|
dropout=self.classifier_dropout)
|
1223
1225
|
return model
|
@@ -1443,6 +1445,9 @@ class SeqBowEstimator(BaseEstimator):
|
|
1443
1445
|
if class_ls is not None:
|
1444
1446
|
loss_details['class_loss'] += float(class_ls.mean())
|
1445
1447
|
|
1448
|
+
sc_obj = None
|
1449
|
+
v_res = None
|
1450
|
+
|
1446
1451
|
for epoch_id in range(self.epochs):
|
1447
1452
|
if self.metric is not None:
|
1448
1453
|
self.metric.reset()
|
@@ -1570,10 +1575,12 @@ class SeqBowEstimator(BaseEstimator):
|
|
1570
1575
|
|
1571
1576
|
class SeqBowMetricEstimator(SeqBowEstimator):
|
1572
1577
|
|
1573
|
-
def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1,
|
1578
|
+
def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1,
|
1579
|
+
use_sdml=False, non_scoring_index=-1, **kwargs):
|
1574
1580
|
super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
|
1575
|
-
|
1576
|
-
|
1581
|
+
self.loss_function = \
|
1582
|
+
GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
|
1583
|
+
else MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
|
1577
1584
|
self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation
|
1578
1585
|
|
1579
1586
|
|
tmnt/modeling.py
CHANGED
@@ -473,7 +473,7 @@ class BaseSeqBowVED(BaseVAE):
|
|
473
473
|
vocab_size=2000,
|
474
474
|
kld=0.1,
|
475
475
|
device='cpu',
|
476
|
-
use_pooling=
|
476
|
+
use_pooling=True,
|
477
477
|
entropy_loss_coef=1000.0,
|
478
478
|
redundancy_reg_penalty=0.0, pre_trained_embedding = None):
|
479
479
|
super(BaseSeqBowVED, self).__init__(device=device, vocab_size=vocab_size)
|
@@ -493,7 +493,9 @@ class BaseSeqBowVED(BaseVAE):
|
|
493
493
|
if pre_trained_embedding is not None:
|
494
494
|
self.embedding = nn.Linear(len(pre_trained_embedding.idx_to_vec),
|
495
495
|
pre_trained_embedding.idx_to_vec[0].size, bias=False)
|
496
|
-
self.apply(self._init_weights)
|
496
|
+
#self.apply(self._init_weights)
|
497
|
+
self.latent_distribution.apply(self._init_weights)
|
498
|
+
self.decoder.apply(self._init_weights)
|
497
499
|
|
498
500
|
def _init_weights(self, module):
|
499
501
|
if isinstance(module, torch.nn.Linear):
|
@@ -1,11 +1,11 @@
|
|
1
1
|
tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
|
2
2
|
tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
|
3
|
-
tmnt/data_loading.py,sha256=
|
3
|
+
tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
|
4
4
|
tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
5
|
-
tmnt/estimator.py,sha256=
|
5
|
+
tmnt/estimator.py,sha256=IIXjtKB09qUqL_lDiDbhd5IVsW7hLuCHo82fF27xp64,77942
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256=
|
8
|
+
tmnt/modeling.py,sha256=Q-CSN0oaftf6RhM3Y3zKk4xw1Wd_WeZmPexZy8nk2Nw,32947
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.44b20240125.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.44b20240125.dist-info/METADATA,sha256=0duXA_NTiacN4bKgC10fnqMdPeOfVEHqy9EDz7EqquU,1403
|
22
|
+
tmnt-0.7.44b20240125.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.44b20240125.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
24
|
+
tmnt-0.7.44b20240125.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.44b20240125.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|