tmnt 0.7.56__py3-none-any.whl → 0.7.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -25,9 +25,7 @@ from typing import List, Tuple, Dict, Optional, Union, NoReturn
 
 import torch
 from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler
-from torchtext.vocab import vocab as build_vocab
-from torchtext.data.utils import get_tokenizer
-from torchtext.vocab import build_vocab_from_iterator
+from tmnt.utils.vocab import build_vocab
 from transformers import DistilBertTokenizer, DistilBertModel, AutoTokenizer, AutoModel, DistilBertTokenizer, BertModel, DistilBertModel, OpenAIGPTModel
 from sklearn.model_selection import StratifiedKFold
 
@@ -42,7 +40,8 @@ llm_catalog = {
     'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'BAAI/bge-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
-    'Alibaba-NLP/gte-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
+    'Alibaba-NLP/gte-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
+    'intfloat/multilingual-e5-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
     ## add more model options here ...
     }
 
@@ -58,17 +57,18 @@ def get_llm_model(model_name):
     tok_fn, model_fn = llm_catalog.get(model_name, ((AutoTokenizer.from_pretrained, AutoModel.from_pretrained)))
     return model_fn(model_name, trust_remote_code=True)
 
-def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
+def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, bow_target_texts=None,
+                                 shuffle=False, device='cpu'):
    label_pipeline = lambda x: label_map.get(x, 0)
    text_pipeline = get_llm_tokenizer(llm_name)
 
    def collate_batch(batch):
        label_list, text_list, mask_list, bow_list = [], [], [], []
-       for (_label, _text) in batch:
+       for (_label, _text, _target_text) in batch:
            label_list.append(label_pipeline(_label))
            tokenized_result = text_pipeline(_text, return_tensors='pt', padding='max_length',
                                             max_length=max_len, truncation=True)
-           bag_of_words,_ = bow_vectorizer.transform([_text])
+           bag_of_words,_ = bow_vectorizer.transform([_target_text])
            processed_text = tokenized_result['input_ids']
            mask = tokenized_result['attention_mask']
            mask_list.append(mask)
@@ -79,10 +79,16 @@ def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batc
        mask_list = torch.vstack(mask_list)
        bow_list = torch.vstack([ sparse_coo_to_tensor(bow_vec.tocoo()) for bow_vec in bow_list ])
        return label_list.to(device), text_list.to(device), mask_list.to(device), bow_list.to(device)
-   return DataLoader(data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
+   if bow_target_texts is not None:
+       assert len(bow_target_texts) == len(data)
+       full_data = [ (label, txt, alt_text) for ((label, txt), alt_text) in zip(data, bow_target_texts)]
+   else:
+       full_data = [ (label, txt, txt) for (label, txt) in data]
+   return DataLoader(full_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
 
-def get_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
-    return SingletonWrapperLoader(get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=shuffle, device=device))
+def get_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, bow_target_texts=None, shuffle=False, device='cpu'):
+    return SingletonWrapperLoader(get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len,
+                                                               bow_target_texts=bow_target_texts, shuffle=shuffle, device=device))
 
 
 def get_llm_paired_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b,
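The new `bow_target_texts` argument lets the bag-of-words reconstruction target differ from the text that is fed to the LLM tokenizer. A minimal usage sketch, with hypothetical data and an assumed fit call on `TMNTVectorizer`; the model name only needs to resolve through `llm_catalog` or its `AutoTokenizer`/`AutoModel` fallback:

```python
# Hypothetical example; the documents, labels and vectorizer fit call are illustrative.
from tmnt.preprocess.vectorizer import TMNTVectorizer
from tmnt.data_loading import get_llm_dataloader

docs = [("sports", "The home side won a tense final match."),
        ("science", "The enzyme catalyzes the reaction at low pH.")]
summaries = ["home side wins final", "enzyme catalyzes reaction"]   # alternative BOW targets
label_map = {"sports": 0, "science": 1}

vectorizer = TMNTVectorizer(vocab_size=2000)
vectorizer.fit_transform([t for _, t in docs] + summaries)          # assumed fit API

loader = get_llm_dataloader(docs, vectorizer, "distilbert-base-uncased", label_map,
                            batch_size=2, max_len=64,
                            bow_target_texts=summaries,   # new in 0.7.58
                            shuffle=True)
```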
tmnt/distribution.py CHANGED
@@ -14,26 +14,20 @@ from torch.nn import Sequential
 import torch
 from scipy import special as sp
 import torch
+from typing import Callable, Literal, Optional, Tuple, TypeVar, Union
+from tmnt.sparse.modeling import TopKEncoder
 
 
 __all__ = ['BaseDistribution', 'GaussianDistribution', 'GaussianUnitVarDistribution', 'LogisticGaussianDistribution',
            'VonMisesDistribution']
 
-
 class BaseDistribution(nn.Module):
 
-    def __init__(self, enc_size, n_latent, device, on_simplex=False):
+    def __init__(self, enc_size, n_latent, device, on_simplex=True):
         super(BaseDistribution, self).__init__()
         self.n_latent = n_latent
         self.enc_size = enc_size
         self.device = device
-        self.mu_encoder = nn.Linear(enc_size, n_latent).to(device)
-        #self.mu_encoder = Sequential(self.mu_proj, nn.Softplus().to(device))
-        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
-        self.softmax = nn.Softmax(dim=1).to(device)
-        self.softplus = nn.Softplus().to(device)
-        self.on_simplex = on_simplex
-        #self.mu_bn.collect_params().setattr('grad_req', 'null')
 
     ## this is required by most priors
     def _get_gaussian_sample(self, mu, lv, batch_size):
@@ -48,11 +42,25 @@ class BaseDistribution(nn.Module):
 
     def get_mu_encoding(self, data, include_bn):
         raise NotImplemented
+
+    def freeze_pre_encoder(self) -> None:
+        raise NotImplemented
 
+    def unfreeze_pre_encoder(self) -> None:
+        raise NotImplemented
 
 
+class SimpleDistribution(BaseDistribution):
+    def __init__(self, enc_size, n_latent, device, on_simplex=False):
+        super(SimpleDistribution, self).__init__(enc_size, n_latent, device, on_simplex=on_simplex)
+        self.mu_encoder = nn.Linear(enc_size, n_latent).to(device)
+        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
+        self.softmax = nn.Softmax(dim=1).to(device)
+        self.softplus = nn.Softplus().to(device)
+        self.on_simplex = on_simplex
+
 
-class GaussianDistribution(BaseDistribution):
+class GaussianDistribution(SimpleDistribution):
     """Gaussian latent distribution with diagnol co-variance.
 
     Parameters:
@@ -99,7 +107,7 @@ class GaussianDistribution(BaseDistribution):
 
 
 
-class GaussianUnitVarDistribution(BaseDistribution):
+class GaussianUnitVarDistribution(SimpleDistribution):
     """Gaussian latent distribution with fixed unit variance.
 
     Parameters:
@@ -142,7 +150,7 @@ class GaussianUnitVarDistribution(BaseDistribution):
         return mu
 
 
-class LogisticGaussianDistribution(BaseDistribution):
+class LogisticGaussianDistribution(SimpleDistribution):
     """Logistic normal/Gaussian latent distribution with specified prior
 
     Parameters:
@@ -199,7 +207,7 @@ class LogisticGaussianDistribution(BaseDistribution):
         return mu
 
 
-class VonMisesDistribution(BaseDistribution):
+class VonMisesDistribution(SimpleDistribution):
 
     def __init__(self, enc_size, n_latent, kappa=100.0, dr=0.1, device='cpu'):
         super(VonMisesDistribution, self).__init__(enc_size, n_latent, device, on_simplex=False)
@@ -239,7 +247,7 @@ class VonMisesDistribution(BaseDistribution):
 
 
 
-class Projection(BaseDistribution):
+class Projection(SimpleDistribution):
 
     def __init__(self, enc_size, n_latent, device='cpu'):
         super(Projection, self).__init__(enc_size, n_latent, device)
@@ -265,6 +273,85 @@ class Projection(BaseDistribution):
         return enc
 
 
-
+class ConceptLogisticGaussianDistribution(BaseDistribution):
+    """Sparse concept encoding with Logistic normal/Gaussian latent distribution with specified prior
+
+    Parameters:
+        n_latent (int): Dimentionality of the latent distribution
+        device (device): Torch computational context (cpu or gpu[id])
+        dr (float): Dropout value for dropout applied post sample. optional (default = 0.2)
+        alpha (float): Value the determines prior variance as 1/alpha - (2/n_latent) + 1/(n_latent^2)
+    """
+    def __init__(self, enc_size, n_latent, sparse_encoder: TopKEncoder, device='cpu', dr=0.1, alpha=1.0):
+        super(ConceptLogisticGaussianDistribution, self).__init__(enc_size, n_latent, device, on_simplex=True)
+        self.n_latent = n_latent
+        self.enc_size = enc_size
+        self.device = device
+        self.sparse_encoder = sparse_encoder.to(device)
+        self.n_concepts = sparse_encoder.get_dict_size()
+        self.sparse_to_mu = nn.Linear(self.n_concepts, n_latent).to(device)
+        self.sparse_bn = nn.BatchNorm1d(self.n_concepts, momentum=0.8, eps=0.0001).to(device)
+        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
+        self.softmax = nn.Softmax(dim=1).to(device)
+        self.on_simplex = True
+        self.alpha = alpha
+
+        prior_var = 1 / self.alpha - (2.0 / n_latent) + 1 / (self.n_latent * self.n_latent)
+        self.prior_var = torch.tensor([prior_var], device=device)
+        self.prior_logvar = torch.tensor([math.log(prior_var)], device=device)
+
+        ## NOTE: the weights to model the log-variance are separate but the sparse encoder is shared
+        ## between the lv_encoder and mu_encoder (above)
+        self.sparse_to_lv = nn.Linear(self.n_concepts, n_latent).to(device)
+        self.lv_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.001).to(device)
+        self.post_sample_dr_o = nn.Dropout(dr)
+
+
+    def freeze_pre_encoder(self):
+        self.sparse_encoder.W_enc.requires_grad = False
+        self.sparse_encoder.b_enc.requires_grad = False
+
+    def unfreeze_pre_encoder(self):
+        self.sparse_encoder.W_enc.requires_grad = True
+        self.sparse_encoder.b_enc.requires_grad = True
 
+    def _get_kl_term(self, mu, lv):
+        posterior_var = torch.exp(lv)
+        delta = mu
+        dt = torch.div(delta * delta, self.prior_var)
+        v_div = torch.div(posterior_var, self.prior_var)
+        lv_div = self.prior_logvar - lv
+        return (0.5 * (torch.sum((v_div + dt + lv_div), 1) - self.n_latent)).to(self.device)
+
+    def forward(self, data, batch_size):
+        """Generate a sample according to the logistic Gaussian latent distribution given the encoder outputs
+        """
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        #sparse_bn = self.sparse_bn(sparse)
+        mu = self.sparse_to_mu(sparse)
+        mu_bn = self.mu_bn(mu)
+        lv = self.sparse_to_lv(sparse)
+        lv_bn = self.lv_bn(lv)
+        z_p = self._get_gaussian_sample(mu_bn, lv_bn, batch_size)
+        KL = self._get_kl_term(mu, lv)
+        z = self.post_sample_dr_o(z_p)
+        return self.softmax(z), KL
 
+    def get_sparse_encoding(self, data):
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        return sparse
+
+    def get_mu_encoding(self, data, include_bn=True, normalize=False):
+        """Provide the distribution mean as the natural result of running the full encoder
+
+        Parameters:
+            data (:class:`mxnet.ndarray.NDArray`): Output of pre-latent encoding layers
+        Returns:
+            encoding (:class:`mxnet.ndarray.NDArray`): Encoding vector representing unnormalized topic proportions
+        """
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        enc = self.sparse_to_mu(sparse)
+        if include_bn:
+            enc = self.mu_bn(enc)
+        mu = self.softmax(enc) if normalize else enc
+        return mu
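The prior variance and KL term used by the new `ConceptLogisticGaussianDistribution` follow the formula stated in its docstring (`1/alpha - 2/n_latent + 1/n_latent^2`). A standalone sketch of the same arithmetic in plain torch, with illustrative shapes and `alpha`:

```python
import torch

def kl_term(mu, lv, alpha=1.0):
    """Recompute the per-example KL term from the diff above for (mu, log-variance) batches."""
    n_latent = mu.shape[1]
    prior_var = 1.0 / alpha - 2.0 / n_latent + 1.0 / (n_latent * n_latent)
    prior_logvar = torch.log(torch.tensor(prior_var))
    posterior_var = torch.exp(lv)
    dt = (mu * mu) / prior_var            # squared mean scaled by prior variance
    v_div = posterior_var / prior_var     # variance ratio
    lv_div = prior_logvar - lv            # log-variance difference
    return 0.5 * (torch.sum(v_div + dt + lv_div, dim=1) - n_latent)

mu = torch.zeros(4, 20)   # batch of 4, n_latent = 20
lv = torch.zeros(4, 20)   # log-variance 0 => posterior variance 1
print(kl_term(mu, lv))    # KL against the alpha=1.0 prior
```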
tmnt/estimator.py CHANGED
@@ -21,6 +21,7 @@ from tmnt.modeling import BowVAEModel, SeqBowVED, BaseVAE
 from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
 from tmnt.eval_npmi import EvaluateNPMI
 from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution
+from tmnt.utils.vocab import Vocab
 
 ## evaluation routines
 from torcheval.metrics import MultilabelAUPRC, MulticlassAUPRC
@@ -38,7 +39,6 @@ import pickle
 from typing import List, Tuple, Dict, Optional, Union, NoReturn
 
 import torch
-import torchtext
 from torch.utils.data import Dataset, DataLoader
 from tqdm import tqdm
 
@@ -249,7 +249,7 @@ class BaseBowEstimator(BaseEstimator):
                    device = device)
 
     @classmethod
-    def from_config(cls, config: Union[str, dict], vocabulary: Union[str, torchtext.vocab.Vocab],
+    def from_config(cls, config: Union[str, dict], vocabulary: Union[str, Vocab],
                     n_labels: int = 0,
                     coherence_coefficient: float = 8.0,
                     coherence_via_encoder: bool = False,
@@ -943,12 +943,13 @@ class SeqBowEstimator(BaseEstimator):
         self._bow_matrix = None
         self.entropy_loss_coef = entropy_loss_coef
         self.pool_encoder = pool_encoder
+        self.freeze_pre_encoder_weights = False
 
 
     @classmethod
     def from_config(cls,
                     config: Union[str, dict],
-                    vocabulary: torchtext.vocab.Vocab,
+                    vocabulary: Vocab,
                     log_interval: int = 1,
                     pretrained_param_file: Optional[str] = None,
                     n_labels: Optional[int] = None,
@@ -974,7 +975,7 @@ class SeqBowEstimator(BaseEstimator):
             raise Exception("Invalid Json Configuration File")
         ldist_def = config['latent_distribution']
         llm_model_name = config['llm_model_name']
-        model = torch.load(pretrained_param_file, map_location=device)
+        model = torch.load(pretrained_param_file, map_location=device, weights_only=False)
 
         latent_distribution = model.latent_distribution
         estimator = cls(llm_model_name = llm_model_name,
@@ -1006,13 +1007,16 @@ class SeqBowEstimator(BaseEstimator):
         config_file = os.path.join(model_dir, 'model.config')
         with open(config_file) as f:
             config = json.loads(f.read())
-        vocab = torch.load(vocab_file)
+        vocab = torch.load(vocab_file, weights_only=False)
         return cls.from_config(config,
                                vocabulary = vocab,
                                log_interval = log_interval,
                                pretrained_param_file = param_file,
                                device = device)
 
+    def freeze_pre_encoder(self):
+        self.freeze_pre_encoder_weights = True
+
 
     def _get_model_bias_initialize(self, train_data):
         model = self._get_model()
@@ -1030,6 +1034,7 @@ class SeqBowEstimator(BaseEstimator):
                          entropy_loss_coef=self.entropy_loss_coef,
                          dropout=self.classifier_dropout)
         return model
+
 
     def _get_config(self):
         config = {}
@@ -1185,8 +1190,10 @@ class SeqBowEstimator(BaseEstimator):
         if self.model is None or not self.warm_start:
             self.model = self._get_model_bias_initialize(train_data)
 
-        model = self.model
+        if self.freeze_pre_encoder_weights:
+            self.model.freeze_pre_encoder()
 
+        model = self.model
         accumulate = False
         v_res = None
 
@@ -1268,7 +1275,8 @@ class SeqBowEstimator(BaseEstimator):
                     update_loss_details(total_ls_2, elbo_ls_2, red_ls_2, None)
 
                 if not accumulate or (batch_id + 1) % accumulate == 0:
-                    torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
+                    if not self.freeze_pre_encoder_weights:
+                        torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
                     optimizer.step()
                     dec_optimizer.step()
                     lr_scheduler.step()
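`SeqBowEstimator.freeze_pre_encoder` is a deferred switch: it only records a flag, and training applies it to the model once the model exists, also skipping gradient-value clipping of the (now frozen) LLM parameters. A hypothetical call sequence, with the estimator assumed to be already configured:

```python
# Hypothetical sketch; `estimator` stands for an already-configured SeqBowEstimator.
estimator.freeze_pre_encoder()      # records freeze_pre_encoder_weights = True

# During training (per the hunks above), the flag has two effects:
#   1. self.model.freeze_pre_encoder() is invoked once the model is built, which turns off
#      requires_grad for every LLM parameter (see tmnt/modeling.py below);
#   2. torch.nn.utils.clip_grad.clip_grad_value_ is skipped for model.llm.parameters(),
#      since those parameters no longer receive gradients.
```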
tmnt/inference.py CHANGED
@@ -18,8 +18,9 @@ from tmnt.utils.recalibrate import recalibrate_scores
 from sklearn.datasets import load_svmlight_file
 from functools import partial
 from tmnt.data_loading import get_llm_tokenizer
-
 from typing import List, Tuple, Dict, Optional, Union, NoReturn
+from scipy.sparse import csr_matrix
+from tmnt.distribution import ConceptLogisticGaussianDistribution
 
 
 MAX_DESIGN_MATRIX = 250000000
@@ -347,6 +348,9 @@ class MetricSeqVEDInferencer(SeqVEDInferencer):
 
 
 
+
+
+
 
 
 
tmnt/modeling.py CHANGED
@@ -45,6 +45,9 @@ class BaseVAE(nn.Module):
         t_npmi_mat = torch.Tensor(npmi_mat).to(self.device)
         self.npmi_with_diversity_loss = NPMILossWithDiversity(t_npmi_mat, device=self.device, npmi_lambda=npmi_lambda, npmi_scale=npmi_scale)
 
+    def freeze_pre_encoder(self):
+        pass
+
     def get_ordered_terms(self):
         """
         Returns the top K terms for each topic based on sensitivity analysis. Terms whose
@@ -56,7 +59,6 @@ class BaseVAE(nn.Module):
         sorted_j = jacobian.argsort(dim=0, descending=True)
         return sorted_j.cpu().numpy()
 
-
     def get_topic_vectors(self):
         """
         Returns unnormalized topic vectors
@@ -126,7 +128,8 @@ class BowVAEModel(BaseVAE):
 
     def _init_weights(self, module):
         if isinstance(module, torch.nn.Linear):
-            torch.nn.init.xavier_uniform_(module.weight.data)
+            torch.nn.init.kaiming_uniform_(module.weight.data)
+            #torch.nn.init.xavier_uniform_(module.weight.data)
 
 
     def _get_encoder(self, dims, dr=0.1):
@@ -360,7 +363,7 @@ class CoherenceRegularizer(nn.Module):
 class BaseSeqBowVED(BaseVAE):
     def __init__(self,
                  llm,
-                 latent_dist,
+                 latent_dist: BaseDistribution,
                  num_classes=0,
                  dropout=0.0,
                  vocab_size=2000,
@@ -401,6 +404,11 @@ class BaseSeqBowVED(BaseVAE):
             return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
         else:
             return model_output.last_hidden_state[:,0,:]
+
+    def freeze_pre_encoder(self):
+        for p in self.llm.parameters():
+            p.requires_grad = False
+        self.latent_distribution.freeze_pre_encoder()
 
     def get_ordered_terms(self):
         """
@@ -447,6 +455,7 @@ class SeqBowVED(BaseSeqBowVED):
             self.classifier = torch.nn.Sequential()
             self.classifier.add_module("dr", nn.Dropout(self.dropout).to(self.device))
             self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes).to(self.device))
+
 
     def forward(self, input_ids, attention_mask, bow=None): # pylint: disable=arguments-differ
         llm_output = self.llm(input_ids, attention_mask)
@@ -462,9 +471,8 @@ class SeqBowVED(BaseSeqBowVED):
             classifier_outputs = self.classifier(z_mu)
         else:
             classifier_outputs = None
-        redundancy_loss = entropy_loss
         ii_loss = self.add_npmi_and_diversity_loss(elbo)
-        redundancy_loss = entropy_loss #self.get_redundancy_penalty()
+        redundancy_loss = ii_loss #self.get_redundancy_penalty()
        return ii_loss, rec_loss, KL_loss, redundancy_loss, classifier_outputs
 
 
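`BaseSeqBowVED.freeze_pre_encoder` turns off `requires_grad` for the wrapped LLM and delegates to the latent distribution's own `freeze_pre_encoder`. A quick check of that effect; `model` here is assumed to be an already-built `SeqBowVED` whose latent distribution overrides `freeze_pre_encoder` (e.g. `ConceptLogisticGaussianDistribution` — the `SimpleDistribution` subclasses inherit the `raise NotImplemented` stubs):

```python
# Assumes `model` is an already-constructed SeqBowVED instance (see caveat above).
before = sum(p.numel() for p in model.parameters() if p.requires_grad)
model.freeze_pre_encoder()
after = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {before} -> {after}")   # LLM weights drop out of the count
assert all(not p.requires_grad for p in model.llm.parameters())
```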
tmnt/preprocess/vectorizer.py CHANGED
@@ -6,17 +6,12 @@ Copyright (c) 2019-2021 The MITRE Corporation.
 import io
 import os
 import json
-import torchtext
-from torchtext.vocab import vocab as build_vocab
+from tmnt.utils.vocab import Vocab, build_vocab
 import glob
 from multiprocessing import Pool, cpu_count
 from mantichora import mantichora
 from atpbar import atpbar
 import collections
-import threading
-import logging
-import threading
-import scipy
 import scipy.sparse as sp
 import numpy as np
 from queue import Queue
@@ -25,9 +20,14 @@ from sklearn.datasets import dump_svmlight_file
 from tmnt.preprocess import BasicTokenizer
 from typing import List, Dict, Optional, Any, Tuple
 from collections import OrderedDict
+from sklearn.utils import check_array
+from sklearn.preprocessing import normalize
+from sklearn.feature_extraction.text import TfidfTransformer
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted
+from sklearn.feature_extraction._stop_words import ENGLISH_STOP_WORDS
 
-__all__ = ['TMNTVectorizer']
 
+__all__ = ['TMNTVectorizer', 'CTFIDFVectorizer']
 
 class TMNTVectorizer(object):
 
@@ -57,10 +57,12 @@ class TMNTVectorizer(object):
     def __init__(self, text_key: str = 'body', label_key: Optional[str] = None, min_doc_size: int = 1,
                  label_remap: Optional[Dict[str,str]] = None,
                  json_out_dir: Optional[str] = None, vocab_size: int = 2000, file_pat: str = '*.json',
-                 encoding: str = 'utf-8', initial_vocabulary: Optional[torchtext.vocab.Vocab] = None,
+                 encoding: str = 'utf-8', initial_vocabulary: Optional[Vocab] = None,
                  additional_feature_keys: List[str] = None, stop_word_file: str = None,
                  split_char: str = ',',
                  max_ws_tokens: int = -1,
+                 source_key: Optional[str] = None,
+                 source_json: Optional[str] = None,
                  count_vectorizer_kwargs: Dict[str, Any] = {'max_df':0.95, 'min_df':0.0, 'stop_words':'english'}):
         self.encoding = encoding
         self.max_ws_tokens = max_ws_tokens
@@ -78,12 +80,53 @@ class TMNTVectorizer(object):
         self.cv_kwargs = self._update_count_vectorizer_args(count_vectorizer_kwargs, stop_word_file)
         if not 'token_pattern' in self.cv_kwargs:
             self.cv_kwargs['token_pattern'] = r'\b[A-Za-z][A-Za-z]+\b'
+        if source_key and source_json:
+            source_terms = self._get_source_specific_terms(source_json, 10, text_key, source_key,
+                                                           {'token_pattern': self.cv_kwargs['token_pattern'],
+                                                            'stop_words': self.cv_kwargs['stop_words'],
+                                                            'max_df': 1.0, 'min_df':0.0})
+            stop_words = set(source_terms)
+            stop_words.update(set(ENGLISH_STOP_WORDS))
+            self.cv_kwargs['stop_words'] = frozenset(stop_words)
         self.vectorizer = CountVectorizer(max_features=self.vocab_size,
                                           vocabulary=(initial_vocabulary.get_itos() if initial_vocabulary else None),
                                           **self.cv_kwargs)
         self.label_map = {}
 
 
+    def _get_source_specific_terms(self, json_file, k: int, text_key: str, source_key: str, cv_kwargs):
+        by_source = {}
+        with io.open(json_file) as fp:
+            for l in fp:
+                js = json.loads(l)
+                txt = js[text_key]
+                src = js[source_key]
+                if src not in by_source:
+                    by_source[src] = []
+                by_source[src].append(txt)
+        docs_by_source = [''.join(txts) for txts in by_source.values()]
+        print(cv_kwargs)
+        count_vectorizer = CountVectorizer(**cv_kwargs)
+        count = count_vectorizer.fit_transform(docs_by_source)
+        ctfidf = CTFIDFVectorizer().fit_transform(count)
+        tok_to_idx = list(count_vectorizer.vocabulary_.items())
+        tok_to_idx.sort(key = lambda x: x[1])
+        ordered_vocab = OrderedDict([ (k,1) for (k,_) in tok_to_idx ])
+        ovocab = build_vocab(ordered_vocab)
+        per_source_tokens = []
+        for i in range(ctfidf.shape[0]):
+            ts = ctfidf[i].toarray().squeeze()
+            per_source_tokens.append(ovocab.lookup_tokens((-ts).argsort()[:k]))
+        final_tokens_intersect = set(per_source_tokens[0])
+        final_tokens_union = set(per_source_tokens[0])
+        for src_tokens in per_source_tokens:
+            final_tokens_intersect.intersection_update(src_tokens)
+            final_tokens_union.update(src_tokens)
+        res = final_tokens_union - final_tokens_intersect
+        print("Removed terms = {}".format(res))
+        return final_tokens_union - final_tokens_intersect
+
+
 
     def _update_count_vectorizer_args(self, cv_kwargs: Dict[str, Any], stop_word_file: str) -> Dict[str, Any]:
         if stop_word_file:
@@ -113,11 +156,11 @@ class TMNTVectorizer(object):
         return list(set(wds))
 
 
-    def get_vocab(self) -> torchtext.vocab.Vocab:
-        """Returns the Torchtext vocabulary associated with the vectorizer
+    def get_vocab(self) -> Vocab:
+        """Returns the vocabulary associated with the vectorizer
 
         Returns:
-            Torchtext vocabulary
+            vocabulary
         """
         if self.vocab is not None:
             return self.vocab
@@ -375,3 +418,78 @@ class TMNTVectorizer(object):
         y = self._get_ys_dir(json_dir)
         return X, y
 
+
+
+class CTFIDFVectorizer(TfidfTransformer):
+    def __init__(self, *args, **kwargs):
+        super(CTFIDFVectorizer, self).__init__(*args, **kwargs)
+        self._idf_diag = None
+
+    def fit(self, X: sp.csr_matrix):
+        """Learn the idf vector (global term weights)
+
+        Parameters
+        ----------
+        X : sparse matrix of shape n_samples, n_features)
+            A matrix of term/token counts.
+
+        """
+
+        # Prepare input
+        X = check_array(X, accept_sparse=('csr', 'csc'))
+        if not sp.issparse(X):
+            X = sp.csr_matrix(X)
+        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+
+        # Calculate IDF scores
+        _, n_features = X.shape
+        df = np.squeeze(np.asarray(X.sum(axis=0)))
+        avg_nr_samples = int(X.sum(axis=1).mean())
+        idf = np.log(avg_nr_samples / df)
+        self._idf_diag = sp.diags(idf, offsets=0,
+                                  shape=(n_features, n_features),
+                                  format='csr',
+                                  dtype=dtype)
+        setattr(self, 'idf_', True)
+        return self
+
+    def transform(self, X: sp.csr_matrix, copy=True) -> sp.csr_matrix:
+        """Transform a count-based matrix to c-TF-IDF
+
+        Parameters
+        ----------
+        X : sparse matrix of (n_samples, n_features)
+            a matrix of term/token counts
+
+        Returns
+        -------
+        vectors : sparse matrix of shape (n_samples, n_features)
+
+        """
+
+        # Prepare input
+        X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES, copy=copy)
+        if not sp.issparse(X):
+            X = sp.csr_matrix(X, dtype=np.float64)
+
+        _, n_features = X.shape
+
+        # idf_ being a property, the automatic attributes detection
+        # does not work as usual and we need to specify the attribute
+        # name:
+        check_is_fitted(self, attributes=["idf_"],
+                        msg='idf vector is not fitted')
+
+        # Check if expected nr features is found
+        expected_n_features = self._idf_diag.shape[0]
+        if n_features != expected_n_features:
+            raise ValueError("Input has n_features=%d while the model"
+                             " has been trained with n_features=%d" % (
+                                 n_features, expected_n_features))
+
+        X = X * self._idf_diag
+
+        if self.norm:
+            X = normalize(X, axis=1, norm='l1', copy=False)
+
+        return X
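`CTFIDFVectorizer` implements the class-based TF-IDF (c-TF-IDF) recipe: concatenate each class's or source's documents into one pseudo-document, count terms, then weight each term by `log(average tokens per pseudo-document / total term frequency)`. A small end-to-end sketch mirroring how `_get_source_specific_terms` uses it, with toy data:

```python
from sklearn.feature_extraction.text import CountVectorizer
from tmnt.preprocess.vectorizer import CTFIDFVectorizer   # added in 0.7.58

# One concatenated pseudo-document per source/class (toy example).
docs_by_source = [
    "goal match striker goal referee match",
    "enzyme protein binding enzyme substrate",
]
cv = CountVectorizer()
counts = cv.fit_transform(docs_by_source)
ctfidf = CTFIDFVectorizer().fit_transform(counts)          # shape: (n_sources, n_terms)

# Top-k most source-specific terms per row, as in _get_source_specific_terms
vocab = cv.get_feature_names_out()
for row in range(ctfidf.shape[0]):
    scores = ctfidf[row].toarray().squeeze()
    print([vocab[i] for i in (-scores).argsort()[:3]])
```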
tmnt/utils/vocab.py ADDED
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from typing import Dict, List, Optional, Iterable, OrderedDict
+from collections import Counter
+
+class Vocab(nn.Module):
+    r"""Creates a vocab object which maps tokens to indices.
+
+    Args:
+        vocab (torch.classes.torchtext.Vocab or torchtext._torchtext.Vocab): a cpp vocab object.
+    """
+
+    def __init__(self, stoi: Dict):
+        super(Vocab, self).__init__()
+        self.stoi = stoi
+        self.itos = list(stoi.keys())
+
+    def forward(self, tokens: List[str]) -> List[int]:
+        r"""Calls the `lookup_indices` method
+
+        Args:
+            tokens: a list of tokens used to lookup their corresponding `indices`.
+
+        Returns:
+            The indices associated with a list of `tokens`.
+        """
+        return [self.stoi[t] for t in tokens]
+
+    def __len__(self) -> int:
+        r"""
+        Returns:
+            The length of the vocab.
+        """
+        return len(self.stoi)
+
+    def __contains__(self, token: str) -> bool:
+        r"""
+        Args:
+            token: The token for which to check the membership.
+
+        Returns:
+            Whether the token is member of vocab or not.
+        """
+        return self.stoi.__contains__(token)
+
+    def __getitem__(self, token: str) -> int:
+        r"""
+        Args:
+            token: The token used to lookup the corresponding index.
+
+        Returns:
+            The index corresponding to the associated token.
+        """
+        return self.stoi[token]
+
+    def insert_token(self, token: str, index: int) -> None:
+        r"""
+        Args:
+            token: The token used to lookup the corresponding index.
+            index: The index corresponding to the associated token.
+        Raises:
+            RuntimeError: If `index` is not in range [0, Vocab.size()] or if `token` already exists in the vocab.
+        """
+        if not token in self.stoi:
+            self.stoi[token] = index
+            self.itos[index] = token
+
+    def lookup_token(self, index: int) -> str:
+        r"""
+        Args:
+            index: The index corresponding to the associated token.
+
+        Returns:
+            token: The token used to lookup the corresponding index.
+
+        Raises:
+            RuntimeError: If `index` not in range [0, itos.size()).
+        """
+        return self.itos[index]
+
+    def lookup_tokens(self, indices: List[int]) -> List[str]:
+        r"""
+        Args:
+            indices: The `indices` used to lookup their corresponding`tokens`.
+
+        Returns:
+            The `tokens` associated with `indices`.
+
+        Raises:
+            RuntimeError: If an index within `indices` is not int range [0, itos.size()).
+        """
+        return [ self.itos[i] for i in indices]
+
+    def lookup_indices(self, tokens: List[str]) -> List[int]:
+        r"""
+        Args:
+            tokens: the tokens used to lookup their corresponding `indices`.
+
+        Returns:
+            The 'indices` associated with `tokens`.
+        """
+        return [ self.stoi[t] for t in tokens ]
+
+    def get_stoi(self) -> Dict[str, int]:
+        r"""
+        Returns:
+            Dictionary mapping tokens to indices.
+        """
+        return self.stoi
+
+    def get_itos(self) -> List[str]:
+        r"""
+        Returns:
+            List mapping indices to tokens.
+        """
+        return self.itos
+
+
+
+def build_vocab(
+        odict: OrderedDict
+) -> Vocab:
+    """
+    """
+    dict_by_position = dict(zip(odict.keys(), range(0,len(odict))))
+    return Vocab(dict_by_position)
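The new `tmnt.utils.vocab.Vocab` is a minimal stand-in for the parts of `torchtext.vocab.Vocab` that TMNT actually used (indexing, membership, `lookup_tokens`, `get_itos`). A behaviour sketch based directly on the code above:

```python
from collections import OrderedDict
from tmnt.utils.vocab import build_vocab

# build_vocab ignores the counts; insertion order of the OrderedDict gives the indices.
vocab = build_vocab(OrderedDict([("the", 1), ("cat", 1), ("sat", 1)]))
assert len(vocab) == 3
assert vocab["cat"] == 1 and "dog" not in vocab
assert vocab(["cat", "sat"]) == [1, 2]               # nn.Module __call__ -> lookup_indices
assert vocab.lookup_tokens([0, 2]) == ["the", "sat"]
assert vocab.get_itos() == ["the", "cat", "sat"]
```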
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: tmnt
-Version: 0.7.56
+Version: 0.7.58
 Summary: Topic modeling neural toolkit
 Home-page: https://github.com/mitre/tmnt.git
 Author: The MITRE Corporation
@@ -14,6 +14,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: NOTICE
 Requires-Dist: optuna
+Requires-Dist: datasets
 Requires-Dist: mantichora>=0.9.5
 Requires-Dist: transformers[torch]
 Requires-Dist: torcheval
@@ -32,7 +33,17 @@ Requires-Dist: numba
 Requires-Dist: scipy==1.12.0
 Requires-Dist: tabulate>=0.8.7
 Requires-Dist: torch>=2.1.2
-Requires-Dist: torchtext>=0.13.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
 topic models as neural network-based variational auto-encoders.
@@ -1,14 +1,14 @@
 tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
-tmnt/data_loading.py,sha256=vsAMyHGi3fuOFDmqo_zenNKOtVQiuqMHA-iPYWYpGKE,18873
-tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
-tmnt/estimator.py,sha256=htQ_JeUedEYWLPIBDbDhEL5deWtHiVNRKQN1528SybY,67751
+tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
+tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
+tmnt/estimator.py,sha256=KnnvSNXm6cRL0GwDrGdgqqPX5ZubpCQ0WqcSXJDkUU4,68072
 tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
-tmnt/inference.py,sha256=da8qAnjTDTuWQfPEOQewOfgikqE00XT1xGMiO2mckI4,15679
-tmnt/modeling.py,sha256=O1V7ppU7J6pvESTvdEoV9BXbEF4Z-J1OHnRtszuagaA,29956
+tmnt/inference.py,sha256=Iwc2_w7QrS1epiVEm_Ewx5sYFNNMDfvhMJETOgJqm0E,15783
+tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
-tmnt/preprocess/vectorizer.py,sha256=RkdivqP76qAJDianV09lONad9NbfBVWLZgIbU_P1-zo,15796
+tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
 tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
 tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
 tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -17,9 +17,10 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
 tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
-tmnt-0.7.56.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
-tmnt-0.7.56.dist-info/METADATA,sha256=jk7-JlrqxLTACr0LsMoLGXT0nq0VVQIkWFoFNqYlEPE,1436
-tmnt-0.7.56.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
-tmnt-0.7.56.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-tmnt-0.7.56.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
-tmnt-0.7.56.dist-info/RECORD,,
+tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
+tmnt-0.7.58.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.58.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.58.dist-info/METADATA,sha256=drdqhfVdpDs5LD_FMAMZjPRWw_TnNqFlGsh0QGtm8QE,1663
+tmnt-0.7.58.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+tmnt-0.7.58.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.58.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.6.0)
+Generator: setuptools (78.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 