tmnt 0.7.44b20240125__py3-none-any.whl → 0.7.44b20240127__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/estimator.py +8 -4
- tmnt/modeling.py +2 -1
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/METADATA +1 -1
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/RECORD +8 -8
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/LICENSE +0 -0
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/NOTICE +0 -0
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/WHEEL +0 -0
- {tmnt-0.7.44b20240125.dist-info → tmnt-0.7.44b20240127.dist-info}/top_level.txt +0 -0
tmnt/estimator.py
CHANGED
@@ -1405,7 +1405,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1405
1405
|
"params": [
|
1406
1406
|
p for n, p in model.llm.named_parameters() if (n in decay_parameters and p.requires_grad)
|
1407
1407
|
],
|
1408
|
-
"weight_decay": 1e-
|
1408
|
+
"weight_decay": 1e-3,
|
1409
1409
|
},
|
1410
1410
|
{ "params": [
|
1411
1411
|
p for n, p in model.llm.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
@@ -1452,6 +1452,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1452
1452
|
if self.metric is not None:
|
1453
1453
|
self.metric.reset()
|
1454
1454
|
model.train()
|
1455
|
+
model.llm.train()
|
1455
1456
|
|
1456
1457
|
for (batch_id, (data, aux_batch)) in enumerate(joint_loader):
|
1457
1458
|
# data_batch is either a 2-tuple of: (labeled, unlabeled)
|
@@ -1468,11 +1469,14 @@ class SeqBowEstimator(BaseEstimator):
|
|
1468
1469
|
update_loss_details(total_ls, elbo_ls, red_ls, label_ls)
|
1469
1470
|
if aux_batch is not None:
|
1470
1471
|
update_loss_details(total_ls_2, elbo_ls_2, red_ls_2, None)
|
1472
|
+
|
1473
|
+
#debug
|
1471
1474
|
|
1472
1475
|
if not accumulate or (batch_id + 1) % accumulate == 0:
|
1473
|
-
torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
|
1474
|
-
|
1476
|
+
#torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
|
1477
|
+
optimizer.step()
|
1475
1478
|
dec_optimizer.step()
|
1479
|
+
lr_scheduler.step()
|
1476
1480
|
model.zero_grad()
|
1477
1481
|
step_num += 1
|
1478
1482
|
if (batch_id + 1) % (self.log_interval) == 0:
|
@@ -1593,7 +1597,7 @@ class SeqBowMetricEstimator(SeqBowEstimator):
|
|
1593
1597
|
def _get_model(self):
|
1594
1598
|
llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
|
1595
1599
|
model = MetricSeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
|
1596
|
-
vocab_size = len(self.vocabulary), use_pooling=
|
1600
|
+
vocab_size = len(self.vocabulary), use_pooling=self.pool_encoder,
|
1597
1601
|
dropout=self.classifier_dropout, entropy_loss_coef=self.entropy_loss_coef)
|
1598
1602
|
return model
|
1599
1603
|
|
tmnt/modeling.py
CHANGED
@@ -595,10 +595,11 @@ class MetricSeqBowVED(BaseSeqBowVED):
|
|
595
595
|
elbo = elbo1 + elbo2
|
596
596
|
rec_loss = rec_loss1 + rec_loss2
|
597
597
|
KL_loss = KL_loss1 + KL_loss2
|
598
|
-
z_mu1 = self.latent_distribution.get_mu_encoding(
|
598
|
+
z_mu1 = self.latent_distribution.get_mu_encoding(enc2)
|
599
599
|
z_mu2 = self.latent_distribution.get_mu_encoding(enc2)
|
600
600
|
redundancy_loss = entropy_loss1 + entropy_loss2 #self.get_redundancy_penalty()
|
601
601
|
return elbo, rec_loss, KL_loss, redundancy_loss, z_mu1, z_mu2
|
602
|
+
#return elbo, rec_loss, KL_loss, redundancy_loss, enc1, enc2
|
602
603
|
|
603
604
|
|
604
605
|
class GeneralizedSDMLLoss(_Loss):
|
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
|
|
2
2
|
tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
|
3
3
|
tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
|
4
4
|
tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
5
|
-
tmnt/estimator.py,sha256=
|
5
|
+
tmnt/estimator.py,sha256=kQZ42MfOBBZuF0TQVdd9vBlw101ZlXk77mlws2ZvAS4,78014
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
8
|
-
tmnt/modeling.py,sha256
|
8
|
+
tmnt/modeling.py,sha256=-fvmbT-KXr8luhELnCAOyZ-DUbTUd65cKRNRaH49EKI,33016
|
9
9
|
tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
|
10
10
|
tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
|
11
11
|
tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.44b20240127.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.44b20240127.dist-info/METADATA,sha256=RNb_SRd6cyvKGKSJT1NKTDdjjVVUfhDXqRuFIxmy2dE,1403
|
22
|
+
tmnt-0.7.44b20240127.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.44b20240127.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
24
|
+
tmnt-0.7.44b20240127.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.44b20240127.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|