tmnt 0.7.44b20240125__py3-none-any.whl → 0.7.44b20240127__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/estimator.py CHANGED
@@ -1405,7 +1405,7 @@ class SeqBowEstimator(BaseEstimator):
1405
1405
  "params": [
1406
1406
  p for n, p in model.llm.named_parameters() if (n in decay_parameters and p.requires_grad)
1407
1407
  ],
1408
- "weight_decay": 1e-5,
1408
+ "weight_decay": 1e-3,
1409
1409
  },
1410
1410
  { "params": [
1411
1411
  p for n, p in model.llm.named_parameters() if (n not in decay_parameters and p.requires_grad)
@@ -1452,6 +1452,7 @@ class SeqBowEstimator(BaseEstimator):
1452
1452
  if self.metric is not None:
1453
1453
  self.metric.reset()
1454
1454
  model.train()
1455
+ model.llm.train()
1455
1456
 
1456
1457
  for (batch_id, (data, aux_batch)) in enumerate(joint_loader):
1457
1458
  # data_batch is either a 2-tuple of: (labeled, unlabeled)
@@ -1468,11 +1469,14 @@ class SeqBowEstimator(BaseEstimator):
1468
1469
  update_loss_details(total_ls, elbo_ls, red_ls, label_ls)
1469
1470
  if aux_batch is not None:
1470
1471
  update_loss_details(total_ls_2, elbo_ls_2, red_ls_2, None)
1472
+
1473
+ #debug
1471
1474
 
1472
1475
  if not accumulate or (batch_id + 1) % accumulate == 0:
1473
- torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
1474
- lr_scheduler.step()
1476
+ #torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
1477
+ optimizer.step()
1475
1478
  dec_optimizer.step()
1479
+ lr_scheduler.step()
1476
1480
  model.zero_grad()
1477
1481
  step_num += 1
1478
1482
  if (batch_id + 1) % (self.log_interval) == 0:
@@ -1593,7 +1597,7 @@ class SeqBowMetricEstimator(SeqBowEstimator):
1593
1597
  def _get_model(self):
1594
1598
  llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
1595
1599
  model = MetricSeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
1596
- vocab_size = len(self.vocabulary), use_pooling=(self.llm_model_name.startswith("sentence-transformers")),
1600
+ vocab_size = len(self.vocabulary), use_pooling=self.pool_encoder,
1597
1601
  dropout=self.classifier_dropout, entropy_loss_coef=self.entropy_loss_coef)
1598
1602
  return model
1599
1603
 
tmnt/modeling.py CHANGED
@@ -595,10 +595,11 @@ class MetricSeqBowVED(BaseSeqBowVED):
595
595
  elbo = elbo1 + elbo2
596
596
  rec_loss = rec_loss1 + rec_loss2
597
597
  KL_loss = KL_loss1 + KL_loss2
598
- z_mu1 = self.latent_distribution.get_mu_encoding(enc1)
598
+ z_mu1 = self.latent_distribution.get_mu_encoding(enc2)
599
599
  z_mu2 = self.latent_distribution.get_mu_encoding(enc2)
600
600
  redundancy_loss = entropy_loss1 + entropy_loss2 #self.get_redundancy_penalty()
601
601
  return elbo, rec_loss, KL_loss, redundancy_loss, z_mu1, z_mu2
602
+ #return elbo, rec_loss, KL_loss, redundancy_loss, enc1, enc2
602
603
 
603
604
 
604
605
  class GeneralizedSDMLLoss(_Loss):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.44b20240125
3
+ Version: 0.7.44b20240127
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -2,10 +2,10 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
3
  tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
4
4
  tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
- tmnt/estimator.py,sha256=IIXjtKB09qUqL_lDiDbhd5IVsW7hLuCHo82fF27xp64,77942
5
+ tmnt/estimator.py,sha256=kQZ42MfOBBZuF0TQVdd9vBlw101ZlXk77mlws2ZvAS4,78014
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=Q-CSN0oaftf6RhM3Y3zKk4xw1Wd_WeZmPexZy8nk2Nw,32947
8
+ tmnt/modeling.py,sha256=-fvmbT-KXr8luhELnCAOyZ-DUbTUd65cKRNRaH49EKI,33016
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.44b20240125.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.44b20240125.dist-info/METADATA,sha256=0duXA_NTiacN4bKgC10fnqMdPeOfVEHqy9EDz7EqquU,1403
22
- tmnt-0.7.44b20240125.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.44b20240125.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- tmnt-0.7.44b20240125.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.44b20240125.dist-info/RECORD,,
20
+ tmnt-0.7.44b20240127.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.44b20240127.dist-info/METADATA,sha256=RNb_SRd6cyvKGKSJT1NKTDdjjVVUfhDXqRuFIxmy2dE,1403
22
+ tmnt-0.7.44b20240127.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.44b20240127.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ tmnt-0.7.44b20240127.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.44b20240127.dist-info/RECORD,,