tmnt 0.7.44b20240126__py3-none-any.whl → 0.7.44b20240128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -39,7 +39,8 @@ llm_catalog = {
39
39
  'openai-gpt' : (AutoTokenizer.from_pretrained, OpenAIGPTModel.from_pretrained),
40
40
  'sentence-transformers/all-mpnet-base-v2' : (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
41
41
  'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
42
- 'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
42
+ 'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
43
+ 'BAAI/bge-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
43
44
  ## add more model options here if desired
44
45
  }
45
46
 
tmnt/estimator.py CHANGED
@@ -1115,7 +1115,7 @@ class SeqBowEstimator(BaseEstimator):
1115
1115
  classifier_dropout = 0.0,
1116
1116
  pure_classifier_objective = False,
1117
1117
  validate_each_epoch = False,
1118
- entropy_loss_coef = 1000.0,
1118
+ entropy_loss_coef = 0.0,
1119
1119
  pool_encoder = True,
1120
1120
  **kwargs):
1121
1121
  super(SeqBowEstimator, self).__init__(*args, **kwargs)
@@ -1527,13 +1527,6 @@ class SeqBowEstimator(BaseEstimator):
1527
1527
  else:
1528
1528
  self._output_status("Epoch [{}]. Objective = {} ==> PPL = {}. NPMI ={}. Redundancy = {}."
1529
1529
  .format(epoch_id, sc_obj, v_res['ppl'], v_res['npmi'], v_res['redundancy']))
1530
- #if self.reporter:
1531
- #if 'accuracy' in v_res:
1532
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1533
- # "redundancy": v_res['redundancy'], "accuracy": v_res['accuracy']})
1534
- #else:
1535
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1536
- # "redundancy": v_res['redundancy']})
1537
1530
  return sc_obj, v_res
1538
1531
 
1539
1532
 
@@ -1610,13 +1603,8 @@ class SeqBowMetricEstimator(SeqBowEstimator):
1610
1603
  bow_batch_b = seqs_b[3].to_dense()
1611
1604
  sums += bow_batch_a.sum(axis=0)
1612
1605
  sums += bow_batch_b.sum(axis=0)
1613
- return sums.cpu().numpy() #def _get_model_bias_initialize(self, train_data):
1614
- # model = self._get_model()
1615
- # tr_bow_matrix = self._get_bow_matrix(train_data)
1616
- #model.initialize_bias_terms(tr_bow_matrix.sum(axis=0))
1617
- # return model
1606
+ return sums.cpu().numpy()
1618
1607
 
1619
-
1620
1608
  def _get_bow_matrix(self, dataloader, cache=False):
1621
1609
  bow_matrix = []
1622
1610
  for _, seqs in enumerate(dataloader):
@@ -1670,10 +1658,5 @@ class SeqBowMetricEstimator(SeqBowEstimator):
1670
1658
  v_res = self.validate(model, dev_data, epoch_id)
1671
1659
  self._output_status("Epoch [{}]. ==> elbo loss = {}; kldiv loss = {}"
1672
1660
  .format(epoch_id, v_res['elbo_ls'], v_res['kl_ls']))
1673
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1674
- # "redundancy": v_res['redundancy']})
1675
- #if self.reporter:
1676
- # self.reporter(epoch=epoch_id+1, objective=v_res['avg_prec'], time_step=time.time(), coherence=0.0,
1677
- # perplexity=0.0, redundancy=0.0)
1678
1661
  return v_res['kl_ls'], v_res
1679
1662
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.44b20240126
3
+ Version: 0.7.44b20240128
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -1,8 +1,8 @@
1
1
  tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
- tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
3
+ tmnt/data_loading.py,sha256=IB7qgoeIY6a4i-YDB7kwWUU3LMvlCGF6_PgzlWDjkc8,18392
4
4
  tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
- tmnt/estimator.py,sha256=kQZ42MfOBBZuF0TQVdd9vBlw101ZlXk77mlws2ZvAS4,78014
5
+ tmnt/estimator.py,sha256=cRdA3s3_PmbSU36xYc8cfano_rkqEl9j_0FM3eZ8IA8,76953
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
8
  tmnt/modeling.py,sha256=372eAVcnI5xcBYRwSO8N0XK_ECWHwRw7KfuIB8uz3RA,33018
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.44b20240126.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.44b20240126.dist-info/METADATA,sha256=0MCgY6kov5ji5IrnhU4Eru5DAqcsgdD3_rqFc_9JXyE,1403
22
- tmnt-0.7.44b20240126.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.44b20240126.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- tmnt-0.7.44b20240126.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.44b20240126.dist-info/RECORD,,
20
+ tmnt-0.7.44b20240128.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.44b20240128.dist-info/METADATA,sha256=b1P0wAHx0ISujY5QjZMKW7XjfH_f1pzxOAFLxOQnN3k,1403
22
+ tmnt-0.7.44b20240128.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.44b20240128.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ tmnt-0.7.44b20240128.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.44b20240128.dist-info/RECORD,,