tmnt 0.7.44b20240127__py3-none-any.whl → 0.7.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -39,7 +39,8 @@ llm_catalog = {
39
39
  'openai-gpt' : (AutoTokenizer.from_pretrained, OpenAIGPTModel.from_pretrained),
40
40
  'sentence-transformers/all-mpnet-base-v2' : (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
41
41
  'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
42
- 'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
42
+ 'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
43
+ 'BAAI/bge-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
43
44
  ## add more model options here if desired
44
45
  }
45
46
 
tmnt/estimator.py CHANGED
@@ -1115,7 +1115,7 @@ class SeqBowEstimator(BaseEstimator):
1115
1115
  classifier_dropout = 0.0,
1116
1116
  pure_classifier_objective = False,
1117
1117
  validate_each_epoch = False,
1118
- entropy_loss_coef = 1000.0,
1118
+ entropy_loss_coef = 0.0,
1119
1119
  pool_encoder = True,
1120
1120
  **kwargs):
1121
1121
  super(SeqBowEstimator, self).__init__(*args, **kwargs)
@@ -1527,13 +1527,6 @@ class SeqBowEstimator(BaseEstimator):
1527
1527
  else:
1528
1528
  self._output_status("Epoch [{}]. Objective = {} ==> PPL = {}. NPMI ={}. Redundancy = {}."
1529
1529
  .format(epoch_id, sc_obj, v_res['ppl'], v_res['npmi'], v_res['redundancy']))
1530
- #if self.reporter:
1531
- #if 'accuracy' in v_res:
1532
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1533
- # "redundancy": v_res['redundancy'], "accuracy": v_res['accuracy']})
1534
- #else:
1535
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1536
- # "redundancy": v_res['redundancy']})
1537
1530
  return sc_obj, v_res
1538
1531
 
1539
1532
 
@@ -1610,13 +1603,8 @@ class SeqBowMetricEstimator(SeqBowEstimator):
1610
1603
  bow_batch_b = seqs_b[3].to_dense()
1611
1604
  sums += bow_batch_a.sum(axis=0)
1612
1605
  sums += bow_batch_b.sum(axis=0)
1613
- return sums.cpu().numpy() #def _get_model_bias_initialize(self, train_data):
1614
- # model = self._get_model()
1615
- # tr_bow_matrix = self._get_bow_matrix(train_data)
1616
- #model.initialize_bias_terms(tr_bow_matrix.sum(axis=0))
1617
- # return model
1606
+ return sums.cpu().numpy()
1618
1607
 
1619
-
1620
1608
  def _get_bow_matrix(self, dataloader, cache=False):
1621
1609
  bow_matrix = []
1622
1610
  for _, seqs in enumerate(dataloader):
@@ -1670,10 +1658,5 @@ class SeqBowMetricEstimator(SeqBowEstimator):
1670
1658
  v_res = self.validate(model, dev_data, epoch_id)
1671
1659
  self._output_status("Epoch [{}]. ==> elbo loss = {}; kldiv loss = {}"
1672
1660
  .format(epoch_id, v_res['elbo_ls'], v_res['kl_ls']))
1673
- #session.report({"objective": sc_obj, "coherence": v_res['npmi'], "perplexity": v_res['ppl'],
1674
- # "redundancy": v_res['redundancy']})
1675
- #if self.reporter:
1676
- # self.reporter(epoch=epoch_id+1, objective=v_res['avg_prec'], time_step=time.time(), coherence=0.0,
1677
- # perplexity=0.0, redundancy=0.0)
1678
1661
  return v_res['kl_ls'], v_res
1679
1662
 
tmnt/modeling.py CHANGED
@@ -595,11 +595,11 @@ class MetricSeqBowVED(BaseSeqBowVED):
595
595
  elbo = elbo1 + elbo2
596
596
  rec_loss = rec_loss1 + rec_loss2
597
597
  KL_loss = KL_loss1 + KL_loss2
598
- z_mu1 = self.latent_distribution.get_mu_encoding(enc2)
599
- z_mu2 = self.latent_distribution.get_mu_encoding(enc2)
598
+ #z_mu1 = self.latent_distribution.get_mu_encoding(enc2)
599
+ #z_mu2 = self.latent_distribution.get_mu_encoding(enc2)
600
600
  redundancy_loss = entropy_loss1 + entropy_loss2 #self.get_redundancy_penalty()
601
- return elbo, rec_loss, KL_loss, redundancy_loss, z_mu1, z_mu2
602
- #return elbo, rec_loss, KL_loss, redundancy_loss, enc1, enc2
601
+ #return elbo, rec_loss, KL_loss, redundancy_loss, z_mu1, z_mu2
602
+ return elbo, rec_loss, KL_loss, redundancy_loss, enc1, enc2
603
603
 
604
604
 
605
605
  class GeneralizedSDMLLoss(_Loss):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.44b20240127
3
+ Version: 0.7.46
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -9,7 +9,7 @@ License: Apache
9
9
  Classifier: Programming Language :: Python :: 3
10
10
  Classifier: License :: OSI Approved :: Apache Software License
11
11
  Classifier: Operating System :: OS Independent
12
- Requires-Python: >=3.8
12
+ Requires-Python: >=3.10
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
15
  License-File: NOTICE
@@ -27,7 +27,9 @@ Requires-Dist: pyOpenSSL ==18.0.0
27
27
  Requires-Dist: PySocks ==1.6.8
28
28
  Requires-Dist: sacremoses >=0.0.38
29
29
  Requires-Dist: sentence-splitter ==1.4
30
- Requires-Dist: umap-learn ==0.4.6
30
+ Requires-Dist: umap-learn[plot] >=0.5.5
31
+ Requires-Dist: numba
32
+ Requires-Dist: scipy
31
33
  Requires-Dist: tabulate >=0.8.7
32
34
  Requires-Dist: torch >=2.1.2
33
35
  Requires-Dist: torchtext >=0.13.0
@@ -35,7 +37,7 @@ Requires-Dist: torchtext >=0.13.0
35
37
  The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
36
38
  topic models as neural network-based variational auto-encoders.
37
39
 
38
- Current stable version is: 0.7.44
40
+ Current stable version is: 0.7.46
39
41
 
40
42
  Documentation can be found here: https://tmnt.readthedocs.io/en/stable/
41
43
 
@@ -1,11 +1,11 @@
1
1
  tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
- tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
3
+ tmnt/data_loading.py,sha256=IB7qgoeIY6a4i-YDB7kwWUU3LMvlCGF6_PgzlWDjkc8,18392
4
4
  tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
- tmnt/estimator.py,sha256=kQZ42MfOBBZuF0TQVdd9vBlw101ZlXk77mlws2ZvAS4,78014
5
+ tmnt/estimator.py,sha256=cRdA3s3_PmbSU36xYc8cfano_rkqEl9j_0FM3eZ8IA8,76953
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=-fvmbT-KXr8luhELnCAOyZ-DUbTUd65cKRNRaH49EKI,33016
8
+ tmnt/modeling.py,sha256=372eAVcnI5xcBYRwSO8N0XK_ECWHwRw7KfuIB8uz3RA,33018
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.44b20240127.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.44b20240127.dist-info/METADATA,sha256=RNb_SRd6cyvKGKSJT1NKTDdjjVVUfhDXqRuFIxmy2dE,1403
22
- tmnt-0.7.44b20240127.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.44b20240127.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- tmnt-0.7.44b20240127.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.44b20240127.dist-info/RECORD,,
20
+ tmnt-0.7.46.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.46.dist-info/METADATA,sha256=KLktIuJoTOtPvY1uML9pgNJwRE-Rxact3yLk092gw7I,1443
22
+ tmnt-0.7.46.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.46.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ tmnt-0.7.46.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.46.dist-info/RECORD,,