tmnt 0.7.0b20230910__py3-none-any.whl → 0.7.0b20230912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -92,7 +92,7 @@ def get_llm_paired_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_ma
92
92
 
93
93
  class StratifiedPairedLLMLoader():
94
94
 
95
- def __init__(self, data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b, device='cpu'):
95
+ def __init__(self, data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b, num_batches=0, device='cpu'):
96
96
  self.data_a = data_a
97
97
  self.data_b = data_b
98
98
  self.bow_vectorizer = bow_vectorizer
@@ -102,7 +102,7 @@ class StratifiedPairedLLMLoader():
102
102
  self.max_len_a = max_len_a
103
103
  self.max_len_b = max_len_b
104
104
  self.device = device
105
- self.num_batches = max(len(data_a), len(data_b)) // batch_size
105
+ self.num_batches = num_batches or max(len(data_a), len(data_b)) // batch_size
106
106
  self.stratified_sampler = StratifiedDualBatchSampler(np.array([label_map[l] for (l,_) in data_a]),
107
107
  np.array([label_map[l] for (l,_) in data_b]),
108
108
  batch_size,
@@ -419,22 +419,22 @@ class StratifiedDualBatchSampler:
419
419
  self.shuffle = shuffle
420
420
  self.batch_size = batch_size
421
421
  self.num_batches = num_batches
422
- counts_a = Counter(y_a)
423
- counts_b = Counter(y_b)
422
+ self.counts_a = Counter(y_a)
423
+ self.counts_b = Counter(y_b)
424
424
  self.class_weights_a = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
425
425
  self.class_weights_b = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
426
- for k in counts_a:
427
- self.class_weights_a[k] = counts_a[k] / len(y_a)
428
- for k in counts_b:
429
- self.class_weights_b[k] = counts_b[k] / len(y_b)
426
+ for k in self.counts_a:
427
+ self.class_weights_a[k] = self.counts_a[k] / len(y_a)
428
+ for k in self.counts_b:
429
+ self.class_weights_b[k] = self.counts_b[k] / len(y_b)
430
430
  self.class_indices_a = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
431
431
  self.class_indices_b = [0] * (max(np.max(y_b), np.max(y_a)) + 1)
432
432
  for i in range(len(self.class_indices_a)):
433
433
  self.class_indices_a[i] = list(np.where(y_a == i)[0])
434
434
  for i in range(len(self.class_indices_b)):
435
435
  self.class_indices_b[i] = list(np.where(y_b == i)[0])
436
- self.a_only = counts_a.keys() - counts_b.keys()
437
- self.b_only = counts_b.keys() - counts_a.keys()
436
+ self.a_only = self.counts_a.keys() - self.counts_b.keys()
437
+ self.b_only = self.counts_b.keys() - self.counts_a.keys()
438
438
  self.use_with_replacement = (self.batch_size > len(self.class_weights_a))
439
439
 
440
440
  def _pop_leave_last(self, li):
@@ -451,14 +451,14 @@ class StratifiedDualBatchSampler:
451
451
  for i in range(self.num_batches):
452
452
  if i % 2 == 0:
453
453
  classes_a = list(WeightedRandomSampler(self.class_weights_a, self.batch_size, replacement=self.use_with_replacement))
454
- b_list = list(self.b_only)
454
+ b_list = list(self.counts_b)
455
455
  random.shuffle(b_list)
456
456
  classes_b = [ self._pop_leave_last(b_list) if a in self.a_only else a for a in classes_a]
457
457
  batch_indices_a = [ self.class_indices_a[c][next(samplers_a[c])] for c in classes_a]
458
458
  batch_indices_b = [ self.class_indices_b[c][next(samplers_b[c])] for c in classes_b]
459
459
  else:
460
460
  classes_b = list(WeightedRandomSampler(self.class_weights_b, self.batch_size, replacement=self.use_with_replacement))
461
- a_list = list(self.a_only)
461
+ a_list = list(self.counts_a)
462
462
  random.shuffle(a_list)
463
463
  classes_a = [ self._pop_leave_last(a_list) if b in self.b_only else b for b in classes_b]
464
464
  batch_indices_a = [ self.class_indices_a[c][next(samplers_a[c])] for c in classes_a]
tmnt/distribution.py CHANGED
@@ -168,7 +168,7 @@ class LogisticGaussianDistribution(BaseDistribution):
168
168
  class VonMisesDistribution(BaseDistribution):
169
169
 
170
170
  def __init__(self, enc_size, n_latent, kappa=100.0, dr=0.1, device='cpu'):
171
- super(VonMisesDistribution, self).__init__(enc_size, n_latent, device)
171
+ super(VonMisesDistribution, self).__init__(enc_size, n_latent, device, on_simplex=False)
172
172
  self.device = device
173
173
  self.kappa = kappa
174
174
  self.kld_v = torch.tensor(VonMisesDistribution._vmf_kld(self.kappa, self.n_latent), device=device)
tmnt/estimator.py CHANGED
@@ -1231,6 +1231,7 @@ class SeqBowEstimator(BaseEstimator):
1231
1231
  llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
1232
1232
  model = SeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
1233
1233
  vocab_size = len(self.vocabulary), use_pooling = (self.llm_model_name.startswith("sentence-transformers")),
1234
+ entropy_loss_coef=self.entropy_loss_coef,
1234
1235
  dropout=self.classifier_dropout)
1235
1236
  return model
1236
1237
 
@@ -1583,10 +1584,10 @@ class SeqBowEstimator(BaseEstimator):
1583
1584
 
1584
1585
  class SeqBowMetricEstimator(SeqBowEstimator):
1585
1586
 
1586
- def __init__(self, *args, sdml_smoothing_factor=0.3, non_scoring_index=-1, **kwargs):
1587
+ def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, non_scoring_index=-1, **kwargs):
1587
1588
  super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
1588
1589
  #self.loss_function = GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index)
1589
- self.loss_function = MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor)
1590
+ self.loss_function = MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
1590
1591
  self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation
1591
1592
 
1592
1593
 
tmnt/modeling.py CHANGED
@@ -506,7 +506,7 @@ class BaseSeqBowVED(BaseVAE):
506
506
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
507
507
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
508
508
  else:
509
- model_output.last_hidden_state[:,0,:]
509
+ return model_output.last_hidden_state[:,0,:]
510
510
 
511
511
  def get_ordered_terms(self):
512
512
  """
@@ -551,8 +551,8 @@ class SeqBowVED(BaseSeqBowVED):
551
551
  super(SeqBowVED, self).__init__(*args, **kwargs)
552
552
  if self.has_classifier:
553
553
  self.classifier = torch.nn.Sequential()
554
- self.classifier.add_module("dr", nn.Dropout(self.dropout))
555
- self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes))
554
+ self.classifier.add_module("dr", nn.Dropout(self.dropout).to(self.device))
555
+ self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes).to(self.device))
556
556
 
557
557
  def forward(self, input_ids, attention_mask, bow=None): # pylint: disable=arguments-differ
558
558
  llm_output = self.llm(input_ids, attention_mask)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.0b20230910
3
+ Version: 0.7.0b20230912
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -2,12 +2,12 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/bert_handling.py,sha256=4l78pzLjK0rbsGa3YxCsfVEndJPzaXTaj_928ZPZfSk,24677
3
3
  tmnt/common_params.py,sha256=uNWs1UuaTx6xlxbS5LailnuXhuIKCg8kCJqxes-kAGY,2547
4
4
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
5
- tmnt/data_loading.py,sha256=msX-l7ov0jynqxUvhPFe-5b6zc5FOlLiF1ruJXOWMrU,18485
6
- tmnt/distribution.py,sha256=-rKCVzpdu8P2NfBOB7QdZoFYGcCT3Q9K9x6fK0tRwew,8364
7
- tmnt/estimator.py,sha256=Fk0JQ0mmU9HTOrYf_q7jXsXIN1KKM2n1kRK_sinQmp4,78555
5
+ tmnt/data_loading.py,sha256=Fnn3Pdrw16e6IR_QEPusiUfSCrHlk-3ddKeyzQW_5JE,18569
6
+ tmnt/distribution.py,sha256=JrJe2HaF2uub0S8RxLAjSykg_AF6atvgUhNWbbaxaMo,8382
7
+ tmnt/estimator.py,sha256=lMcuVvPIBiI1ChUpllCFR6ygWnclpcOcYA-ht5r4W3s,78680
8
8
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
9
9
  tmnt/inference.py,sha256=Hc0PRmUBLr9YbfqAGyw6-1BQqiwUUEtS8ehsWkr7QJk,18399
10
- tmnt/modeling.py,sha256=mWAlyqTYjZrkztYPrlvb24dtuvyMQ7v7mzyQNLOBs6o,32901
10
+ tmnt/modeling.py,sha256=Zw6F3PZcaJsUZVuyk-sQIMNWMaqcjQH5cetZTmcqI7g,32940
11
11
  tmnt/selector.py,sha256=DWJlbdWKNxJmyLI9IRxCa3FDmaCilxTDzNfIf5mpBqc,9578
12
12
  tmnt/trainer.py,sha256=xaJtU_vHAPbos9q86NNdnwz7kpUF5BxGyTWbDG_NBA0,25802
13
13
  tmnt/classifier/__init__.py,sha256=1gLyJjCMHmnWdf-J4gnRs4uhbebtvCs9RgnZze1HTXY,67
@@ -30,9 +30,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
30
30
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
31
31
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
32
32
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
33
- tmnt-0.7.0b20230910.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
34
- tmnt-0.7.0b20230910.dist-info/METADATA,sha256=cLvhJfcBDDppP-t0GF0IPmayQ6T6NfeHkKj153d7lCk,997
35
- tmnt-0.7.0b20230910.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
36
- tmnt-0.7.0b20230910.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
37
- tmnt-0.7.0b20230910.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
38
- tmnt-0.7.0b20230910.dist-info/RECORD,,
33
+ tmnt-0.7.0b20230912.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
34
+ tmnt-0.7.0b20230912.dist-info/METADATA,sha256=vx4UrdxOGQ2BPb9o-5mipGun99SVmyw0DOL4aS58gOU,997
35
+ tmnt-0.7.0b20230912.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
36
+ tmnt-0.7.0b20230912.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
37
+ tmnt-0.7.0b20230912.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
38
+ tmnt-0.7.0b20230912.dist-info/RECORD,,