tmnt 0.7.44b20240123__py3-none-any.whl → 0.7.44b20240125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -38,7 +38,8 @@ llm_catalog = {
38
38
  'bert-base-uncased' : (AutoTokenizer.from_pretrained, BertModel.from_pretrained),
39
39
  'openai-gpt' : (AutoTokenizer.from_pretrained, OpenAIGPTModel.from_pretrained),
40
40
  'sentence-transformers/all-mpnet-base-v2' : (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
41
- 'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
41
+ 'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
42
+ 'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
42
43
  ## add more model options here if desired
43
44
  }
44
45
 
tmnt/estimator.py CHANGED
@@ -1116,6 +1116,7 @@ class SeqBowEstimator(BaseEstimator):
1116
1116
  pure_classifier_objective = False,
1117
1117
  validate_each_epoch = False,
1118
1118
  entropy_loss_coef = 1000.0,
1119
+ pool_encoder = True,
1119
1120
  **kwargs):
1120
1121
  super(SeqBowEstimator, self).__init__(*args, **kwargs)
1121
1122
  self.pure_classifier_objective = pure_classifier_objective
@@ -1135,6 +1136,7 @@ class SeqBowEstimator(BaseEstimator):
1135
1136
  self.decoder_lr = decoder_lr
1136
1137
  self._bow_matrix = None
1137
1138
  self.entropy_loss_coef = entropy_loss_coef
1139
+ self.pool_encoder = pool_encoder
1138
1140
 
1139
1141
 
1140
1142
  @classmethod
@@ -1217,7 +1219,7 @@ class SeqBowEstimator(BaseEstimator):
1217
1219
  def _get_model(self):
1218
1220
  llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
1219
1221
  model = SeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
1220
- vocab_size = len(self.vocabulary), use_pooling = (self.llm_model_name.startswith("sentence-transformers")),
1222
+ vocab_size = len(self.vocabulary), use_pooling = self.pool_encoder,
1221
1223
  entropy_loss_coef=self.entropy_loss_coef,
1222
1224
  dropout=self.classifier_dropout)
1223
1225
  return model
@@ -1443,6 +1445,9 @@ class SeqBowEstimator(BaseEstimator):
1443
1445
  if class_ls is not None:
1444
1446
  loss_details['class_loss'] += float(class_ls.mean())
1445
1447
 
1448
+ sc_obj = None
1449
+ v_res = None
1450
+
1446
1451
  for epoch_id in range(self.epochs):
1447
1452
  if self.metric is not None:
1448
1453
  self.metric.reset()
@@ -1570,10 +1575,12 @@ class SeqBowEstimator(BaseEstimator):
1570
1575
 
1571
1576
  class SeqBowMetricEstimator(SeqBowEstimator):
1572
1577
 
1573
- def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, non_scoring_index=-1, **kwargs):
1578
+ def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1,
1579
+ use_sdml=False, non_scoring_index=-1, **kwargs):
1574
1580
  super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
1575
- #self.loss_function = GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index)
1576
- self.loss_function = MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
1581
+ self.loss_function = \
1582
+ GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index) if use_sdml \
1583
+ else MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
1577
1584
  self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation
1578
1585
 
1579
1586
 
tmnt/modeling.py CHANGED
@@ -473,7 +473,7 @@ class BaseSeqBowVED(BaseVAE):
473
473
  vocab_size=2000,
474
474
  kld=0.1,
475
475
  device='cpu',
476
- use_pooling=False,
476
+ use_pooling=True,
477
477
  entropy_loss_coef=1000.0,
478
478
  redundancy_reg_penalty=0.0, pre_trained_embedding = None):
479
479
  super(BaseSeqBowVED, self).__init__(device=device, vocab_size=vocab_size)
@@ -493,7 +493,9 @@ class BaseSeqBowVED(BaseVAE):
493
493
  if pre_trained_embedding is not None:
494
494
  self.embedding = nn.Linear(len(pre_trained_embedding.idx_to_vec),
495
495
  pre_trained_embedding.idx_to_vec[0].size, bias=False)
496
- self.apply(self._init_weights)
496
+ #self.apply(self._init_weights)
497
+ self.latent_distribution.apply(self._init_weights)
498
+ self.decoder.apply(self._init_weights)
497
499
 
498
500
  def _init_weights(self, module):
499
501
  if isinstance(module, torch.nn.Linear):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.44b20240123
3
+ Version: 0.7.44b20240125
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -1,11 +1,11 @@
1
1
  tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
- tmnt/data_loading.py,sha256=_NpAwmpeFBoQp7xtWOLb6i3WS271JoSJqDx9BMrXtKM,18207
3
+ tmnt/data_loading.py,sha256=B47kfq5nrpw2bHYT2qEv2tpCLT7EFwqD7ZDjsoBto_Q,18303
4
4
  tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
- tmnt/estimator.py,sha256=xk4QATqqD8ukxtraOQ6BvSJrdqGTQvX52fNdcgfQ3w8,77801
5
+ tmnt/estimator.py,sha256=IIXjtKB09qUqL_lDiDbhd5IVsW7hLuCHo82fF27xp64,77942
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
8
- tmnt/modeling.py,sha256=NTgjTqvi3sUsEfQa8Kq8lGW3vST905B8OkNhQmNwpwA,32841
8
+ tmnt/modeling.py,sha256=Q-CSN0oaftf6RhM3Y3zKk4xw1Wd_WeZmPexZy8nk2Nw,32947
9
9
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.44b20240123.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.44b20240123.dist-info/METADATA,sha256=BvdBJQro8PU8RZPCKCXaK7-Ui30wTQDAfK-hNQT0qlE,1403
22
- tmnt-0.7.44b20240123.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.44b20240123.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- tmnt-0.7.44b20240123.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.44b20240123.dist-info/RECORD,,
20
+ tmnt-0.7.44b20240125.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.44b20240125.dist-info/METADATA,sha256=0duXA_NTiacN4bKgC10fnqMdPeOfVEHqy9EDz7EqquU,1403
22
+ tmnt-0.7.44b20240125.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.44b20240125.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ tmnt-0.7.44b20240125.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.44b20240125.dist-info/RECORD,,