tmnt 0.7.56__py3-none-any.whl → 0.7.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/data_loading.py +14 -9
- tmnt/distribution.py +88 -3
- tmnt/estimator.py +5 -5
- tmnt/modeling.py +3 -4
- tmnt/preprocess/vectorizer.py +129 -11
- tmnt/utils/vocab.py +126 -0
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/METADATA +13 -3
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/RECORD +12 -11
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/WHEEL +1 -1
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/LICENSE +0 -0
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/NOTICE +0 -0
- {tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/top_level.txt +0 -0
tmnt/data_loading.py
CHANGED
@@ -25,9 +25,7 @@ from typing import List, Tuple, Dict, Optional, Union, NoReturn
 
 import torch
 from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler
-from
-from torchtext.data.utils import get_tokenizer
-from torchtext.vocab import build_vocab_from_iterator
+from tmnt.utils.vocab import build_vocab
 from transformers import DistilBertTokenizer, DistilBertModel, AutoTokenizer, AutoModel, DistilBertTokenizer, BertModel, DistilBertModel, OpenAIGPTModel
 from sklearn.model_selection import StratifiedKFold
 
@@ -58,17 +56,18 @@ def get_llm_model(model_name):
     tok_fn, model_fn = llm_catalog.get(model_name, ((AutoTokenizer.from_pretrained, AutoModel.from_pretrained)))
     return model_fn(model_name, trust_remote_code=True)
 
-def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
+def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, bow_target_texts=None,
+                                 shuffle=False, device='cpu'):
     label_pipeline = lambda x: label_map.get(x, 0)
     text_pipeline = get_llm_tokenizer(llm_name)
 
     def collate_batch(batch):
         label_list, text_list, mask_list, bow_list = [], [], [], []
-        for (_label, _text) in batch:
+        for (_label, _text, _target_text) in batch:
             label_list.append(label_pipeline(_label))
             tokenized_result = text_pipeline(_text, return_tensors='pt', padding='max_length',
                                              max_length=max_len, truncation=True)
-            bag_of_words,_ = bow_vectorizer.transform([_text])
+            bag_of_words,_ = bow_vectorizer.transform([_target_text])
             processed_text = tokenized_result['input_ids']
             mask = tokenized_result['attention_mask']
             mask_list.append(mask)
@@ -79,10 +78,16 @@ def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len,
         mask_list = torch.vstack(mask_list)
         bow_list = torch.vstack([ sparse_coo_to_tensor(bow_vec.tocoo()) for bow_vec in bow_list ])
         return label_list.to(device), text_list.to(device), mask_list.to(device), bow_list.to(device)
-
+    if bow_target_texts is not None:
+        assert len(bow_target_texts) == len(data)
+        full_data = [ (label, txt, alt_text) for ((label, txt), alt_text) in zip(data, bow_target_texts)]
+    else:
+        full_data = [ (label, txt, txt) for (label, txt) in data]
+    return DataLoader(full_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
 
-def get_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
-    return SingletonWrapperLoader(get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=shuffle, device=device))
+def get_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, bow_target_texts=None, shuffle=False, device='cpu'):
+    return SingletonWrapperLoader(get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len,
+                                                               bow_target_texts=bow_target_texts, shuffle=shuffle, device=device))
 
 
 def get_llm_paired_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b,
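The practical effect of this change is that the bag-of-words reconstruction targets can now come from different texts than the ones fed to the LLM tokenizer. A minimal sketch of calling the updated get_llm_dataloader, assuming TMNTVectorizer's fit_transform accepts a list of raw strings; the toy corpus, label map, and model name below are illustrative placeholders, not part of the package:

    from tmnt.data_loading import get_llm_dataloader
    from tmnt.preprocess.vectorizer import TMNTVectorizer

    # (label, text) pairs fed to the LLM tokenizer -- hypothetical toy data
    data = [('sports', 'The team won the final match.'),
            ('science', 'The probe returned new telemetry.')]
    # optional alternative texts used only for the bag-of-words reconstruction targets
    targets = ['team won final match', 'probe returned telemetry']

    vectorizer = TMNTVectorizer(vocab_size=2000)
    vectorizer.fit_transform([t for _, t in data] + targets)   # fit the BoW vocabulary

    loader = get_llm_dataloader(data, vectorizer, 'distilbert-base-uncased',
                                {'sports': 0, 'science': 1}, batch_size=2, max_len=32,
                                bow_target_texts=targets)      # omit to reuse the input text

When bow_target_texts is omitted, the loader falls back to the previous behavior and vectorizes the same text that is tokenized for the LLM.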
tmnt/distribution.py
CHANGED
@@ -14,6 +14,7 @@ from torch.nn import Sequential
 import torch
 from scipy import special as sp
 import torch
+from typing import Callable, Literal, Optional, Tuple, TypeVar, Union
 
 
 __all__ = ['BaseDistribution', 'GaussianDistribution', 'GaussianUnitVarDistribution', 'LogisticGaussianDistribution',
@@ -28,12 +29,10 @@ class BaseDistribution(nn.Module):
         self.enc_size = enc_size
         self.device = device
         self.mu_encoder = nn.Linear(enc_size, n_latent).to(device)
-        #self.mu_encoder = Sequential(self.mu_proj, nn.Softplus().to(device))
         self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
         self.softmax = nn.Softmax(dim=1).to(device)
         self.softplus = nn.Softplus().to(device)
         self.on_simplex = on_simplex
-        #self.mu_bn.collect_params().setattr('grad_req', 'null')
 
     ## this is required by most priors
     def _get_gaussian_sample(self, mu, lv, batch_size):
@@ -266,5 +265,91 @@ class Projection(BaseDistribution):
 
 
 
+class TopK(nn.Module):
+    def __init__(
+        self, k: int, postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU()
+    ):
+        super().__init__()
+        self.k = k
+        self.postact_fn = postact_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        topk = torch.topk(x, k=self.k, dim=-1)
+        values = self.postact_fn(topk.values)
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk.indices, values)
+        return result
+
+class ConceptLogisticGaussianDistribution(nn.Module):
+    """Sparse concept encoding with Logistic normal/Gaussian latent distribution with specified prior
 
-
+    Parameters:
+        n_latent (int): Dimentionality of the latent distribution
+        device (device): Torch computational context (cpu or gpu[id])
+        dr (float): Dropout value for dropout applied post sample. optional (default = 0.2)
+        alpha (float): Value the determines prior variance as 1/alpha - (2/n_latent) + 1/(n_latent^2)
+    """
+    def __init__(self, enc_size, n_latent, n_concepts=16000, k_sparsity=32, device='cpu', dr=0.1, alpha=1.0):
+        super(ConceptLogisticGaussianDistribution, self).__init__()
+        self.n_latent = n_latent
+        self.enc_size = enc_size
+        self.device = device
+        self.activation = TopK(k=k_sparsity)
+        self.core_sparse = Sequential(nn.Linear(enc_size, n_concepts), self.activation).to(device)
+        self.mu_encoder = Sequential(self.core_sparse, nn.Linear(n_concepts, n_latent)).to(device)
+        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
+        self.softmax = nn.Softmax(dim=1).to(device)
+        self.on_simplex = True
+        self.alpha = alpha
+        self.n_concepts = n_concepts
+
+        prior_var = 1 / self.alpha - (2.0 / n_latent) + 1 / (self.n_latent * self.n_latent)
+        self.prior_var = torch.tensor([prior_var], device=device)
+        self.prior_logvar = torch.tensor([math.log(prior_var)], device=device)
+
+        ## NOTE: the weights to model the log-variance are separate but the sparse encoder is shared
+        ## between the lv_encoder and mu_encoder (above)
+        self.lv_encoder = Sequential(self.core_sparse, nn.Linear(n_concepts, n_latent)).to(device)
+        self.lv_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.001).to(device)
+        self.post_sample_dr_o = nn.Dropout(dr)
+
+
+    ## this is required by most priors
+    def _get_gaussian_sample(self, mu, lv, batch_size):
+        eps = Normal(torch.zeros(batch_size, self.n_latent),
+                     torch.ones(batch_size, self.n_latent)).sample().to(self.device)
+        return (mu + torch.exp(0.5*lv).to(self.device) * eps)
+
+    def _get_kl_term(self, mu, lv):
+        posterior_var = torch.exp(lv)
+        delta = mu
+        dt = torch.div(delta * delta, self.prior_var)
+        v_div = torch.div(posterior_var, self.prior_var)
+        lv_div = self.prior_logvar - lv
+        return (0.5 * (torch.sum((v_div + dt + lv_div), 1) - self.n_latent)).to(self.device)
+
+    def forward(self, data, batch_size):
+        """Generate a sample according to the logistic Gaussian latent distribution given the encoder outputs
+        """
+        mu = self.mu_encoder(data)
+        mu_bn = self.mu_bn(mu)
+        lv = self.lv_encoder(data)
+        lv_bn = self.lv_bn(lv)
+        z_p = self._get_gaussian_sample(mu_bn, lv_bn, batch_size)
+        KL = self._get_kl_term(mu, lv)
+        z = self.post_sample_dr_o(z_p)
+        return self.softmax(z), KL
+
+    def get_mu_encoding(self, data, include_bn=True, normalize=False):
+        """Provide the distribution mean as the natural result of running the full encoder
+
+        Parameters:
+            data (:class:`mxnet.ndarray.NDArray`): Output of pre-latent encoding layers
+        Returns:
+            encoding (:class:`mxnet.ndarray.NDArray`): Encoding vector representing unnormalized topic proportions
+        """
+        enc = self.mu_encoder(data)
+        if include_bn:
+            enc = self.mu_bn(enc)
+        mu = self.softmax(enc) if normalize else enc
+        return mu
tmnt/estimator.py
CHANGED
@@ -21,6 +21,7 @@ from tmnt.modeling import BowVAEModel, SeqBowVED, BaseVAE
 from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
 from tmnt.eval_npmi import EvaluateNPMI
 from tmnt.distribution import LogisticGaussianDistribution, BaseDistribution, GaussianDistribution, VonMisesDistribution
+from tmnt.utils.vocab import Vocab
 
 ## evaluation routines
 from torcheval.metrics import MultilabelAUPRC, MulticlassAUPRC
@@ -38,7 +39,6 @@ import pickle
 from typing import List, Tuple, Dict, Optional, Union, NoReturn
 
 import torch
-import torchtext
 from torch.utils.data import Dataset, DataLoader
 from tqdm import tqdm
 
@@ -249,7 +249,7 @@ class BaseBowEstimator(BaseEstimator):
                    device = device)
 
     @classmethod
-    def from_config(cls, config: Union[str, dict], vocabulary: Union[str,
+    def from_config(cls, config: Union[str, dict], vocabulary: Union[str, Vocab],
                     n_labels: int = 0,
                     coherence_coefficient: float = 8.0,
                     coherence_via_encoder: bool = False,
@@ -948,7 +948,7 @@ class SeqBowEstimator(BaseEstimator):
    @classmethod
    def from_config(cls,
                    config: Union[str, dict],
-                   vocabulary:
+                   vocabulary: Vocab,
                    log_interval: int = 1,
                    pretrained_param_file: Optional[str] = None,
                    n_labels: Optional[int] = None,
@@ -974,7 +974,7 @@ class SeqBowEstimator(BaseEstimator):
            raise Exception("Invalid Json Configuration File")
        ldist_def = config['latent_distribution']
        llm_model_name = config['llm_model_name']
-       model = torch.load(pretrained_param_file, map_location=device)
+       model = torch.load(pretrained_param_file, map_location=device, weights_only=False)
 
        latent_distribution = model.latent_distribution
        estimator = cls(llm_model_name = llm_model_name,
@@ -1006,7 +1006,7 @@ class SeqBowEstimator(BaseEstimator):
        config_file = os.path.join(model_dir, 'model.config')
        with open(config_file) as f:
            config = json.loads(f.read())
-       vocab = torch.load(vocab_file)
+       vocab = torch.load(vocab_file, weights_only=False)
        return cls.from_config(config,
                               vocabulary = vocab,
                               log_interval = log_interval,
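The explicit weights_only=False arguments matter because recent PyTorch releases (2.6 and later) flip the default of torch.load to weights_only=True, which refuses to unpickle arbitrary Python objects such as a whole model or a Vocab instance. A short sketch of the pattern, with placeholder file paths:

    import torch

    # Full-object checkpoints (a serialized nn.Module or Vocab) need weights_only=False;
    # only load files from sources you trust, since this path goes through pickle.
    model = torch.load('model.pt', map_location='cpu', weights_only=False)

    # A plain state_dict can still be loaded under the safer default.
    state = torch.load('state_dict.pt', map_location='cpu', weights_only=True)

This keeps the estimator's saved-model loading working on newer PyTorch while leaving the rest of the API unchanged.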
tmnt/modeling.py
CHANGED
@@ -56,7 +56,6 @@ class BaseVAE(nn.Module):
         sorted_j = jacobian.argsort(dim=0, descending=True)
         return sorted_j.cpu().numpy()
 
-
     def get_topic_vectors(self):
         """
         Returns unnormalized topic vectors
@@ -126,7 +125,8 @@ class BowVAEModel(BaseVAE):
 
     def _init_weights(self, module):
         if isinstance(module, torch.nn.Linear):
-            torch.nn.init.xavier_uniform_(module.weight.data)
+            torch.nn.init.kaiming_uniform_(module.weight.data)
+            #torch.nn.init.xavier_uniform_(module.weight.data)
 
 
     def _get_encoder(self, dims, dr=0.1):
@@ -462,9 +462,8 @@ class SeqBowVED(BaseSeqBowVED):
             classifier_outputs = self.classifier(z_mu)
         else:
             classifier_outputs = None
-        redundancy_loss = entropy_loss
         ii_loss = self.add_npmi_and_diversity_loss(elbo)
-        redundancy_loss =
+        redundancy_loss = ii_loss #self.get_redundancy_penalty()
         return ii_loss, rec_loss, KL_loss, redundancy_loss, classifier_outputs
 
 
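The initializer change swaps Xavier-uniform for Kaiming (He) uniform on Linear weights. An init hook like _init_weights is normally applied recursively with Module.apply; a small illustrative sketch, where the two-layer encoder is a made-up stand-in rather than the actual BowVAEModel encoder:

    import torch
    import torch.nn as nn

    def init_weights(module):
        # mirror of the change above: Kaiming-uniform init for Linear weights
        if isinstance(module, nn.Linear):
            torch.nn.init.kaiming_uniform_(module.weight.data)

    encoder = nn.Sequential(nn.Linear(2000, 300), nn.Softplus(), nn.Linear(300, 64))
    encoder.apply(init_weights)   # visits every submodule, including nested ones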
tmnt/preprocess/vectorizer.py
CHANGED
@@ -6,17 +6,12 @@ Copyright (c) 2019-2021 The MITRE Corporation.
 import io
 import os
 import json
-import
-from torchtext.vocab import vocab as build_vocab
+from tmnt.utils.vocab import Vocab, build_vocab
 import glob
 from multiprocessing import Pool, cpu_count
 from mantichora import mantichora
 from atpbar import atpbar
 import collections
-import threading
-import logging
-import threading
-import scipy
 import scipy.sparse as sp
 import numpy as np
 from queue import Queue
@@ -25,9 +20,14 @@ from sklearn.datasets import dump_svmlight_file
 from tmnt.preprocess import BasicTokenizer
 from typing import List, Dict, Optional, Any, Tuple
 from collections import OrderedDict
+from sklearn.utils import check_array
+from sklearn.preprocessing import normalize
+from sklearn.feature_extraction.text import TfidfTransformer
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted
+from sklearn.feature_extraction._stop_words import ENGLISH_STOP_WORDS
 
-__all__ = ['TMNTVectorizer']
 
+__all__ = ['TMNTVectorizer', 'CTFIDFVectorizer']
 
 class TMNTVectorizer(object):
 
@@ -57,10 +57,12 @@ class TMNTVectorizer(object):
     def __init__(self, text_key: str = 'body', label_key: Optional[str] = None, min_doc_size: int = 1,
                  label_remap: Optional[Dict[str,str]] = None,
                  json_out_dir: Optional[str] = None, vocab_size: int = 2000, file_pat: str = '*.json',
-                 encoding: str = 'utf-8', initial_vocabulary: Optional[
+                 encoding: str = 'utf-8', initial_vocabulary: Optional[Vocab] = None,
                  additional_feature_keys: List[str] = None, stop_word_file: str = None,
                  split_char: str = ',',
                  max_ws_tokens: int = -1,
+                 source_key: Optional[str] = None,
+                 source_json: Optional[str] = None,
                  count_vectorizer_kwargs: Dict[str, Any] = {'max_df':0.95, 'min_df':0.0, 'stop_words':'english'}):
         self.encoding = encoding
         self.max_ws_tokens = max_ws_tokens
@@ -78,12 +80,53 @@ class TMNTVectorizer(object):
         self.cv_kwargs = self._update_count_vectorizer_args(count_vectorizer_kwargs, stop_word_file)
         if not 'token_pattern' in self.cv_kwargs:
             self.cv_kwargs['token_pattern'] = r'\b[A-Za-z][A-Za-z]+\b'
+        if source_key and source_json:
+            source_terms = self._get_source_specific_terms(source_json, 10, text_key, source_key,
+                                                           {'token_pattern': self.cv_kwargs['token_pattern'],
+                                                            'stop_words': self.cv_kwargs['stop_words'],
+                                                            'max_df': 1.0, 'min_df':0.0})
+            stop_words = set(source_terms)
+            stop_words.update(set(ENGLISH_STOP_WORDS))
+            self.cv_kwargs['stop_words'] = frozenset(stop_words)
         self.vectorizer = CountVectorizer(max_features=self.vocab_size,
                                           vocabulary=(initial_vocabulary.get_itos() if initial_vocabulary else None),
                                           **self.cv_kwargs)
         self.label_map = {}
 
 
+    def _get_source_specific_terms(self, json_file, k: int, text_key: str, source_key: str, cv_kwargs):
+        by_source = {}
+        with io.open(json_file) as fp:
+            for l in fp:
+                js = json.loads(l)
+                txt = js[text_key]
+                src = js[source_key]
+                if src not in by_source:
+                    by_source[src] = []
+                by_source[src].append(txt)
+        docs_by_source = [''.join(txts) for txts in by_source.values()]
+        print(cv_kwargs)
+        count_vectorizer = CountVectorizer(**cv_kwargs)
+        count = count_vectorizer.fit_transform(docs_by_source)
+        ctfidf = CTFIDFVectorizer().fit_transform(count)
+        tok_to_idx = list(count_vectorizer.vocabulary_.items())
+        tok_to_idx.sort(key = lambda x: x[1])
+        ordered_vocab = OrderedDict([ (k,1) for (k,_) in tok_to_idx ])
+        ovocab = build_vocab(ordered_vocab)
+        per_source_tokens = []
+        for i in range(ctfidf.shape[0]):
+            ts = ctfidf[i].toarray().squeeze()
+            per_source_tokens.append(ovocab.lookup_tokens((-ts).argsort()[:k]))
+        final_tokens_intersect = set(per_source_tokens[0])
+        final_tokens_union = set(per_source_tokens[0])
+        for src_tokens in per_source_tokens:
+            final_tokens_intersect.intersection_update(src_tokens)
+            final_tokens_union.update(src_tokens)
+        res = final_tokens_union - final_tokens_intersect
+        print("Removed terms = {}".format(res))
+        return final_tokens_union - final_tokens_intersect
+
+
 
     def _update_count_vectorizer_args(self, cv_kwargs: Dict[str, Any], stop_word_file: str) -> Dict[str, Any]:
         if stop_word_file:
@@ -113,11 +156,11 @@ class TMNTVectorizer(object):
         return list(set(wds))
 
 
-    def get_vocab(self) ->
-        """Returns the
+    def get_vocab(self) -> Vocab:
+        """Returns the vocabulary associated with the vectorizer
 
         Returns:
-
+            vocabulary
         """
         if self.vocab is not None:
             return self.vocab
@@ -375,3 +418,78 @@ class TMNTVectorizer(object):
         y = self._get_ys_dir(json_dir)
         return X, y
 
+
+
+class CTFIDFVectorizer(TfidfTransformer):
+    def __init__(self, *args, **kwargs):
+        super(CTFIDFVectorizer, self).__init__(*args, **kwargs)
+        self._idf_diag = None
+
+    def fit(self, X: sp.csr_matrix):
+        """Learn the idf vector (global term weights)
+
+        Parameters
+        ----------
+        X : sparse matrix of shape n_samples, n_features)
+            A matrix of term/token counts.
+
+        """
+
+        # Prepare input
+        X = check_array(X, accept_sparse=('csr', 'csc'))
+        if not sp.issparse(X):
+            X = sp.csr_matrix(X)
+        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+
+        # Calculate IDF scores
+        _, n_features = X.shape
+        df = np.squeeze(np.asarray(X.sum(axis=0)))
+        avg_nr_samples = int(X.sum(axis=1).mean())
+        idf = np.log(avg_nr_samples / df)
+        self._idf_diag = sp.diags(idf, offsets=0,
+                                  shape=(n_features, n_features),
+                                  format='csr',
+                                  dtype=dtype)
+        setattr(self, 'idf_', True)
+        return self
+
+    def transform(self, X: sp.csr_matrix, copy=True) -> sp.csr_matrix:
+        """Transform a count-based matrix to c-TF-IDF
+
+        Parameters
+        ----------
+        X : sparse matrix of (n_samples, n_features)
+            a matrix of term/token counts
+
+        Returns
+        -------
+        vectors : sparse matrix of shape (n_samples, n_features)
+
+        """
+
+        # Prepare input
+        X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES, copy=copy)
+        if not sp.issparse(X):
+            X = sp.csr_matrix(X, dtype=np.float64)
+
+        _, n_features = X.shape
+
+        # idf_ being a property, the automatic attributes detection
+        # does not work as usual and we need to specify the attribute
+        # name:
+        check_is_fitted(self, attributes=["idf_"],
+                        msg='idf vector is not fitted')
+
+        # Check if expected nr features is found
+        expected_n_features = self._idf_diag.shape[0]
+        if n_features != expected_n_features:
+            raise ValueError("Input has n_features=%d while the model"
+                             " has been trained with n_features=%d" % (
                                 n_features, expected_n_features))
+
+        X = X * self._idf_diag
+
+        if self.norm:
+            X = normalize(X, axis=1, norm='l1', copy=False)
+
+        return X
tmnt/utils/vocab.py
ADDED
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from typing import Dict, List, Optional, Iterable, OrderedDict
+from collections import Counter
+
+class Vocab(nn.Module):
+    r"""Creates a vocab object which maps tokens to indices.
+
+    Args:
+        vocab (torch.classes.torchtext.Vocab or torchtext._torchtext.Vocab): a cpp vocab object.
+    """
+
+    def __init__(self, stoi: Dict):
+        super(Vocab, self).__init__()
+        self.stoi = stoi
+        self.itos = list(stoi.keys())
+
+    def forward(self, tokens: List[str]) -> List[int]:
+        r"""Calls the `lookup_indices` method
+
+        Args:
+            tokens: a list of tokens used to lookup their corresponding `indices`.
+
+        Returns:
+            The indices associated with a list of `tokens`.
+        """
+        return [self.stoi[t] for t in tokens]
+
+    def __len__(self) -> int:
+        r"""
+        Returns:
+            The length of the vocab.
+        """
+        return len(self.stoi)
+
+    def __contains__(self, token: str) -> bool:
+        r"""
+        Args:
+            token: The token for which to check the membership.
+
+        Returns:
+            Whether the token is member of vocab or not.
+        """
+        return self.stoi.__contains__(token)
+
+    def __getitem__(self, token: str) -> int:
+        r"""
+        Args:
+            token: The token used to lookup the corresponding index.
+
+        Returns:
+            The index corresponding to the associated token.
+        """
+        return self.stoi[token]
+
+    def insert_token(self, token: str, index: int) -> None:
+        r"""
+        Args:
+            token: The token used to lookup the corresponding index.
+            index: The index corresponding to the associated token.
+        Raises:
+            RuntimeError: If `index` is not in range [0, Vocab.size()] or if `token` already exists in the vocab.
+        """
+        if not token in self.stoi:
+            self.stoi[token] = index
+            self.itos[index] = token
+
+    def lookup_token(self, index: int) -> str:
+        r"""
+        Args:
+            index: The index corresponding to the associated token.
+
+        Returns:
+            token: The token used to lookup the corresponding index.
+
+        Raises:
+            RuntimeError: If `index` not in range [0, itos.size()).
+        """
+        return self.itos[index]
+
+    def lookup_tokens(self, indices: List[int]) -> List[str]:
+        r"""
+        Args:
+            indices: The `indices` used to lookup their corresponding`tokens`.
+
+        Returns:
+            The `tokens` associated with `indices`.
+
+        Raises:
+            RuntimeError: If an index within `indices` is not int range [0, itos.size()).
+        """
+        return [ self.itos[i] for i in indices]
+
+    def lookup_indices(self, tokens: List[str]) -> List[int]:
+        r"""
+        Args:
+            tokens: the tokens used to lookup their corresponding `indices`.
+
+        Returns:
+            The 'indices` associated with `tokens`.
+        """
+        return [ self.stoi[t] for t in tokens ]
+
+    def get_stoi(self) -> Dict[str, int]:
+        r"""
+        Returns:
+            Dictionary mapping tokens to indices.
+        """
+        return self.stoi
+
+    def get_itos(self) -> List[str]:
+        r"""
+        Returns:
+            List mapping indices to tokens.
+        """
+        return self.itos
+
+
+
+def build_vocab(
+    odict: OrderedDict
+) -> Vocab:
+    """
+    """
+    dict_by_position = dict(zip(odict.keys(), range(0,len(odict))))
+    return Vocab(dict_by_position)
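This new module replaces the torchtext vocabulary used elsewhere in the package: build_vocab assigns indices by the insertion order of an OrderedDict of tokens (the dict values are ignored), and Vocab exposes the familiar lookup methods. A short usage sketch based directly on the code above:

    from collections import OrderedDict
    from tmnt.utils.vocab import build_vocab

    counts = OrderedDict([('topic', 3), ('model', 2), ('neural', 1)])  # values unused, order matters
    vocab = build_vocab(counts)

    print(len(vocab))                   # 3
    print(vocab['model'])               # 1
    print(vocab.lookup_tokens([2, 0]))  # ['neural', 'topic']
    print(vocab.get_itos())             # ['topic', 'model', 'neural']

Because Vocab is an nn.Module, it can also be serialized with torch.save and reloaded with torch.load(..., weights_only=False), matching the estimator changes above.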
{tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.2
 Name: tmnt
-Version: 0.7.56
+Version: 0.7.57
 Summary: Topic modeling neural toolkit
 Home-page: https://github.com/mitre/tmnt.git
 Author: The MITRE Corporation
@@ -14,6 +14,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: NOTICE
 Requires-Dist: optuna
+Requires-Dist: datasets
 Requires-Dist: mantichora>=0.9.5
 Requires-Dist: transformers[torch]
 Requires-Dist: torcheval
@@ -32,7 +33,16 @@ Requires-Dist: numba
 Requires-Dist: scipy==1.12.0
 Requires-Dist: tabulate>=0.8.7
 Requires-Dist: torch>=2.1.2
-
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
 topic models as neural network-based variational auto-encoders.
{tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
-tmnt/data_loading.py,sha256=
-tmnt/distribution.py,sha256=
-tmnt/estimator.py,sha256=
+tmnt/data_loading.py,sha256=zB3wIBXgl_UKjjRLQgPwCZOVTcjHK4YahxCbsLd70RY,19238
+tmnt/distribution.py,sha256=2YBfaGIiUJc-OjKaotnKmicSEdL4OAGBx3icacbePQ8,14868
+tmnt/estimator.py,sha256=qh-pCbmhhtGpRKKQv10ANyQakuoMYaVH87NM5UIxtyM,67777
 tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
 tmnt/inference.py,sha256=da8qAnjTDTuWQfPEOQewOfgikqE00XT1xGMiO2mckI4,15679
-tmnt/modeling.py,sha256=
+tmnt/modeling.py,sha256=QRnHbNFp85LKp5ILYsJqTeQ3BV0jLPCwKX1Eh-Ed3Dc,29975
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
-tmnt/preprocess/vectorizer.py,sha256=
+tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
 tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
 tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
 tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -17,9 +17,10 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
 tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
-tmnt
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
+tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
+tmnt-0.7.57.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.57.dist-info/METADATA,sha256=EDNrl4p3d9j2UXPwENrMAp0EgaRQuJCBGFvXdYoJTmI,1641
+tmnt-0.7.57.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.57.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+tmnt-0.7.57.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.57.dist-info/RECORD,,
{tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/LICENSE
File without changes

{tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/NOTICE
File without changes

{tmnt-0.7.56.dist-info → tmnt-0.7.57.dist-info}/top_level.txt
File without changes