tmnt 0.7.57__py3-none-any.whl → 0.7.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/data_loading.py CHANGED
@@ -40,7 +40,8 @@ llm_catalog = {
     'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'BAAI/bge-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
-    'Alibaba-NLP/gte-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
+    'Alibaba-NLP/gte-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
+    'intfloat/multilingual-e5-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
     ## add more model options here ...
     }
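The only functional change in this file is the new 'intfloat/multilingual-e5-base' entry in llm_catalog. Each catalog entry maps a Hugging Face model id to the loader pair (AutoTokenizer.from_pretrained, AutoModel.from_pretrained). A minimal sketch of what this entry resolves to, using the Hugging Face loaders directly rather than tmnt's own loading helpers:

from transformers import AutoModel, AutoTokenizer

model_id = 'intfloat/multilingual-e5-base'            # the id added to llm_catalog
tokenizer = AutoTokenizer.from_pretrained(model_id)   # first element of the catalog tuple
encoder = AutoModel.from_pretrained(model_id)         # second element of the catalog tuple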
tmnt/distribution.py CHANGED
@@ -15,24 +15,19 @@ import torch
 from scipy import special as sp
 import torch
 from typing import Callable, Literal, Optional, Tuple, TypeVar, Union
+from tmnt.sparse.modeling import TopKEncoder


 __all__ = ['BaseDistribution', 'GaussianDistribution', 'GaussianUnitVarDistribution', 'LogisticGaussianDistribution',
            'VonMisesDistribution']

-
 class BaseDistribution(nn.Module):

-    def __init__(self, enc_size, n_latent, device, on_simplex=False):
+    def __init__(self, enc_size, n_latent, device, on_simplex=True):
         super(BaseDistribution, self).__init__()
         self.n_latent = n_latent
         self.enc_size = enc_size
         self.device = device
-        self.mu_encoder = nn.Linear(enc_size, n_latent).to(device)
-        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
-        self.softmax = nn.Softmax(dim=1).to(device)
-        self.softplus = nn.Softplus().to(device)
-        self.on_simplex = on_simplex

     ## this is required by most priors
     def _get_gaussian_sample(self, mu, lv, batch_size):
@@ -47,11 +42,25 @@ class BaseDistribution(nn.Module):

     def get_mu_encoding(self, data, include_bn):
         raise NotImplemented
+
+    def freeze_pre_encoder(self) -> None:
+        raise NotImplemented
+
+    def unfreeze_pre_encoder(self) -> None:
+        raise NotImplemented


+class SimpleDistribution(BaseDistribution):
+    def __init__(self, enc_size, n_latent, device, on_simplex=False):
+        super(SimpleDistribution, self).__init__(enc_size, n_latent, device, on_simplex=on_simplex)
+        self.mu_encoder = nn.Linear(enc_size, n_latent).to(device)
+        self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
+        self.softmax = nn.Softmax(dim=1).to(device)
+        self.softplus = nn.Softplus().to(device)
+        self.on_simplex = on_simplex


-class GaussianDistribution(BaseDistribution):
+class GaussianDistribution(SimpleDistribution):
     """Gaussian latent distribution with diagnol co-variance.

     Parameters:
@@ -98,7 +107,7 @@ class GaussianDistribution(BaseDistribution):



-class GaussianUnitVarDistribution(BaseDistribution):
+class GaussianUnitVarDistribution(SimpleDistribution):
     """Gaussian latent distribution with fixed unit variance.

     Parameters:
@@ -141,7 +150,7 @@ class GaussianUnitVarDistribution(BaseDistribution):
         return mu


-class LogisticGaussianDistribution(BaseDistribution):
+class LogisticGaussianDistribution(SimpleDistribution):
     """Logistic normal/Gaussian latent distribution with specified prior

     Parameters:
@@ -198,7 +207,7 @@ class LogisticGaussianDistribution(BaseDistribution):
         return mu


-class VonMisesDistribution(BaseDistribution):
+class VonMisesDistribution(SimpleDistribution):

     def __init__(self, enc_size, n_latent, kappa=100.0, dr=0.1, device='cpu'):
         super(VonMisesDistribution, self).__init__(enc_size, n_latent, device, on_simplex=False)
@@ -238,7 +247,7 @@ class VonMisesDistribution(BaseDistribution):



-class Projection(BaseDistribution):
+class Projection(SimpleDistribution):

     def __init__(self, enc_size, n_latent, device='cpu'):
         super(Projection, self).__init__(enc_size, n_latent, device)
@@ -264,23 +273,7 @@ class Projection(BaseDistribution):
         return enc


-
-class TopK(nn.Module):
-    def __init__(
-        self, k: int, postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU()
-    ):
-        super().__init__()
-        self.k = k
-        self.postact_fn = postact_fn
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        topk = torch.topk(x, k=self.k, dim=-1)
-        values = self.postact_fn(topk.values)
-        result = torch.zeros_like(x)
-        result.scatter_(-1, topk.indices, values)
-        return result
-
-class ConceptLogisticGaussianDistribution(nn.Module):
+class ConceptLogisticGaussianDistribution(BaseDistribution):
     """Sparse concept encoding with Logistic normal/Gaussian latent distribution with specified prior

     Parameters:
@@ -289,19 +282,19 @@ class ConceptLogisticGaussianDistribution(nn.Module):
         dr (float): Dropout value for dropout applied post sample. optional (default = 0.2)
         alpha (float): Value the determines prior variance as 1/alpha - (2/n_latent) + 1/(n_latent^2)
     """
-    def __init__(self, enc_size, n_latent, n_concepts=16000, k_sparsity=32, device='cpu', dr=0.1, alpha=1.0):
-        super(ConceptLogisticGaussianDistribution, self).__init__()
+    def __init__(self, enc_size, n_latent, sparse_encoder: TopKEncoder, device='cpu', dr=0.1, alpha=1.0):
+        super(ConceptLogisticGaussianDistribution, self).__init__(enc_size, n_latent, device, on_simplex=True)
         self.n_latent = n_latent
         self.enc_size = enc_size
         self.device = device
-        self.activation = TopK(k=k_sparsity)
-        self.core_sparse = Sequential(nn.Linear(enc_size, n_concepts), self.activation).to(device)
-        self.mu_encoder = Sequential(self.core_sparse, nn.Linear(n_concepts, n_latent)).to(device)
+        self.sparse_encoder = sparse_encoder.to(device)
+        self.n_concepts = sparse_encoder.get_dict_size()
+        self.sparse_to_mu = nn.Linear(self.n_concepts, n_latent).to(device)
+        self.sparse_bn = nn.BatchNorm1d(self.n_concepts, momentum=0.8, eps=0.0001).to(device)
         self.mu_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.0001).to(device)
         self.softmax = nn.Softmax(dim=1).to(device)
         self.on_simplex = True
         self.alpha = alpha
-        self.n_concepts = n_concepts

         prior_var = 1 / self.alpha - (2.0 / n_latent) + 1 / (self.n_latent * self.n_latent)
         self.prior_var = torch.tensor([prior_var], device=device)
@@ -309,16 +302,18 @@ class ConceptLogisticGaussianDistribution(nn.Module):

         ## NOTE: the weights to model the log-variance are separate but the sparse encoder is shared
         ## between the lv_encoder and mu_encoder (above)
-        self.lv_encoder = Sequential(self.core_sparse, nn.Linear(n_concepts, n_latent)).to(device)
+        self.sparse_to_lv = nn.Linear(self.n_concepts, n_latent).to(device)
         self.lv_bn = nn.BatchNorm1d(n_latent, momentum = 0.8, eps=0.001).to(device)
         self.post_sample_dr_o = nn.Dropout(dr)


-    ## this is required by most priors
-    def _get_gaussian_sample(self, mu, lv, batch_size):
-        eps = Normal(torch.zeros(batch_size, self.n_latent),
-                     torch.ones(batch_size, self.n_latent)).sample().to(self.device)
-        return (mu + torch.exp(0.5*lv).to(self.device) * eps)
+    def freeze_pre_encoder(self):
+        self.sparse_encoder.W_enc.requires_grad = False
+        self.sparse_encoder.b_enc.requires_grad = False
+
+    def unfreeze_pre_encoder(self):
+        self.sparse_encoder.W_enc.requires_grad = True
+        self.sparse_encoder.b_enc.requires_grad = True

     def _get_kl_term(self, mu, lv):
         posterior_var = torch.exp(lv)
@@ -331,14 +326,20 @@ class ConceptLogisticGaussianDistribution(nn.Module):
     def forward(self, data, batch_size):
         """Generate a sample according to the logistic Gaussian latent distribution given the encoder outputs
         """
-        mu = self.mu_encoder(data)
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        #sparse_bn = self.sparse_bn(sparse)
+        mu = self.sparse_to_mu(sparse)
         mu_bn = self.mu_bn(mu)
-        lv = self.lv_encoder(data)
+        lv = self.sparse_to_lv(sparse)
         lv_bn = self.lv_bn(lv)
         z_p = self._get_gaussian_sample(mu_bn, lv_bn, batch_size)
         KL = self._get_kl_term(mu, lv)
         z = self.post_sample_dr_o(z_p)
         return self.softmax(z), KL
+
+    def get_sparse_encoding(self, data):
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        return sparse

     def get_mu_encoding(self, data, include_bn=True, normalize=False):
         """Provide the distribution mean as the natural result of running the full encoder
@@ -348,7 +349,8 @@ class ConceptLogisticGaussianDistribution(nn.Module):

         Returns:
             encoding (:class:`mxnet.ndarray.NDArray`): Encoding vector representing unnormalized topic proportions
         """
-        enc = self.mu_encoder(data)
+        _, sparse, _, _, _ = self.sparse_encoder(data)
+        enc = self.sparse_to_mu(sparse)
         if include_bn:
             enc = self.mu_bn(enc)
         mu = self.softmax(enc) if normalize else enc
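Summary of the refactor: the simple feed-forward encoder pieces (mu_encoder, mu_bn, softmax, softplus) move out of BaseDistribution into a new SimpleDistribution subclass, the existing distributions now inherit from SimpleDistribution, and ConceptLogisticGaussianDistribution takes a pre-built TopKEncoder from tmnt.sparse.modeling instead of constructing its own TopK activation. The sketch below is illustrative only; it uses a stand-in encoder exposing just the interface this diff relies on (W_enc, b_enc, get_dict_size(), and a forward whose second return value is the sparse code), since the real TopKEncoder's constructor and return types are not shown here.

import torch
import torch.nn as nn

class StandInTopKEncoder(nn.Module):
    """Illustrative stand-in for tmnt.sparse.modeling.TopKEncoder (interface assumed from this diff)."""
    def __init__(self, enc_size: int, dict_size: int, k: int = 32):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(enc_size, dict_size) * 0.01)  # frozen by freeze_pre_encoder()
        self.b_enc = nn.Parameter(torch.zeros(dict_size))
        self.k = k
        self.dict_size = dict_size

    def get_dict_size(self) -> int:
        return self.dict_size

    def forward(self, x: torch.Tensor):
        pre = x @ self.W_enc + self.b_enc
        topk = torch.topk(pre, k=self.k, dim=-1)
        sparse = torch.zeros_like(pre).scatter_(-1, topk.indices, torch.relu(topk.values))
        # ConceptLogisticGaussianDistribution unpacks five values and keeps only the second (the sparse code)
        return pre, sparse, topk.indices, topk.values, None

# Hypothetical usage (assumes tmnt 0.7.58 is installed):
# from tmnt.distribution import ConceptLogisticGaussianDistribution
# dist = ConceptLogisticGaussianDistribution(768, 50, sparse_encoder=StandInTopKEncoder(768, 16000))
# topic_probs, kl = dist(torch.randn(8, 768), batch_size=8)
# concepts = dist.get_sparse_encoding(torch.randn(8, 768))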
tmnt/estimator.py CHANGED
@@ -943,6 +943,7 @@ class SeqBowEstimator(BaseEstimator):
         self._bow_matrix = None
         self.entropy_loss_coef = entropy_loss_coef
         self.pool_encoder = pool_encoder
+        self.freeze_pre_encoder_weights = False


     @classmethod
@@ -1013,6 +1014,9 @@ class SeqBowEstimator(BaseEstimator):
                    pretrained_param_file = param_file,
                    device = device)

+    def freeze_pre_encoder(self):
+        self.freeze_pre_encoder_weights = True
+


     def _get_model_bias_initialize(self, train_data):
@@ -1030,6 +1034,7 @@ class SeqBowEstimator(BaseEstimator):
                     entropy_loss_coef=self.entropy_loss_coef,
                     dropout=self.classifier_dropout)
         return model
+

     def _get_config(self):
         config = {}
@@ -1185,8 +1190,10 @@ class SeqBowEstimator(BaseEstimator):
         if self.model is None or not self.warm_start:
             self.model = self._get_model_bias_initialize(train_data)

-        model = self.model
+        if self.freeze_pre_encoder_weights:
+            self.model.freeze_pre_encoder()

+        model = self.model
         accumulate = False
         v_res = None

@@ -1268,7 +1275,8 @@ class SeqBowEstimator(BaseEstimator):
                     update_loss_details(total_ls_2, elbo_ls_2, red_ls_2, None)

                 if not accumulate or (batch_id + 1) % accumulate == 0:
-                    torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
+                    if not self.freeze_pre_encoder_weights:
+                        torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
                     optimizer.step()
                     dec_optimizer.step()
                     lr_scheduler.step()
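SeqBowEstimator gains a freeze_pre_encoder() switch: calling it before training sets freeze_pre_encoder_weights, the model's pre-encoder (the LLM, and the sparse encoder for concept distributions) is then frozen at fit time, and gradient clipping on model.llm.parameters() is skipped. A hedged usage sketch; only freeze_pre_encoder() comes from this diff, the surrounding calls are placeholders:

from tmnt.estimator import SeqBowEstimator

def train_with_frozen_backbone(estimator: SeqBowEstimator, train_data, dev_data=None):
    # Freeze the pre-encoder before training: during training the estimator calls
    # self.model.freeze_pre_encoder() and skips clip_grad_value_ on model.llm.parameters(),
    # so the LLM backbone no longer receives updates.
    estimator.freeze_pre_encoder()
    # The training entry point itself is unchanged by this release; invoke whichever
    # fit method you already use on SeqBowEstimator with train_data / dev_data here.
    return estimator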
tmnt/inference.py CHANGED
@@ -18,8 +18,9 @@ from tmnt.utils.recalibrate import recalibrate_scores
 from sklearn.datasets import load_svmlight_file
 from functools import partial
 from tmnt.data_loading import get_llm_tokenizer
-
 from typing import List, Tuple, Dict, Optional, Union, NoReturn
+from scipy.sparse import csr_matrix
+from tmnt.distribution import ConceptLogisticGaussianDistribution


 MAX_DESIGN_MATRIX = 250000000
@@ -347,6 +348,9 @@ class MetricSeqVEDInferencer(SeqVEDInferencer):



+
+
+



tmnt/modeling.py CHANGED
@@ -45,6 +45,9 @@ class BaseVAE(nn.Module):
         t_npmi_mat = torch.Tensor(npmi_mat).to(self.device)
         self.npmi_with_diversity_loss = NPMILossWithDiversity(t_npmi_mat, device=self.device, npmi_lambda=npmi_lambda, npmi_scale=npmi_scale)

+    def freeze_pre_encoder(self):
+        pass
+
     def get_ordered_terms(self):
         """
         Returns the top K terms for each topic based on sensitivity analysis. Terms whose
@@ -360,7 +363,7 @@ class CoherenceRegularizer(nn.Module):
 class BaseSeqBowVED(BaseVAE):
     def __init__(self,
                  llm,
-                 latent_dist,
+                 latent_dist: BaseDistribution,
                  num_classes=0,
                  dropout=0.0,
                  vocab_size=2000,
@@ -401,6 +404,11 @@ class BaseSeqBowVED(BaseVAE):
             return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
         else:
             return model_output.last_hidden_state[:,0,:]
+
+    def freeze_pre_encoder(self):
+        for p in self.llm.parameters():
+            p.requires_grad = False
+        self.latent_distribution.freeze_pre_encoder()

     def get_ordered_terms(self):
         """
@@ -447,6 +455,7 @@ class SeqBowVED(BaseSeqBowVED):
             self.classifier = torch.nn.Sequential()
             self.classifier.add_module("dr", nn.Dropout(self.dropout).to(self.device))
             self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes).to(self.device))
+

     def forward(self, input_ids, attention_mask, bow=None): # pylint: disable=arguments-differ
         llm_output = self.llm(input_ids, attention_mask)
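The model-level counterpart: BaseVAE.freeze_pre_encoder() is a no-op, while BaseSeqBowVED.freeze_pre_encoder() disables gradients on every LLM parameter and delegates to the latent distribution. A small sketch for checking the effect on an already constructed model (model here stands for any SeqBowVED instance):

def count_trainable_parameters(module) -> int:
    # Number of parameters that will still receive gradient updates.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# before = count_trainable_parameters(model)
# model.freeze_pre_encoder()   # sets requires_grad=False on model.llm parameters and,
#                              # for ConceptLogisticGaussianDistribution, on W_enc / b_enc
# after = count_trainable_parameters(model)
# assert after < before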
tmnt-0.7.57.dist-info/METADATA → tmnt-0.7.58.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: tmnt
-Version: 0.7.57
+Version: 0.7.58
 Summary: Topic modeling neural toolkit
 Home-page: https://github.com/mitre/tmnt.git
 Author: The MITRE Corporation
@@ -40,6 +40,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: license
+Dynamic: license-file
 Dynamic: requires-dist
 Dynamic: requires-python
 Dynamic: summary
tmnt-0.7.57.dist-info/RECORD → tmnt-0.7.58.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
-tmnt/data_loading.py,sha256=zB3wIBXgl_UKjjRLQgPwCZOVTcjHK4YahxCbsLd70RY,19238
-tmnt/distribution.py,sha256=2YBfaGIiUJc-OjKaotnKmicSEdL4OAGBx3icacbePQ8,14868
-tmnt/estimator.py,sha256=qh-pCbmhhtGpRKKQv10ANyQakuoMYaVH87NM5UIxtyM,67777
+tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
+tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
+tmnt/estimator.py,sha256=KnnvSNXm6cRL0GwDrGdgqqPX5ZubpCQ0WqcSXJDkUU4,68072
 tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
-tmnt/inference.py,sha256=da8qAnjTDTuWQfPEOQewOfgikqE00XT1xGMiO2mckI4,15679
-tmnt/modeling.py,sha256=QRnHbNFp85LKp5ILYsJqTeQ3BV0jLPCwKX1Eh-Ed3Dc,29975
+tmnt/inference.py,sha256=Iwc2_w7QrS1epiVEm_Ewx5sYFNNMDfvhMJETOgJqm0E,15783
+tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
 tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
@@ -18,9 +18,9 @@ tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,12
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
 tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
-tmnt-0.7.57.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
-tmnt-0.7.57.dist-info/METADATA,sha256=EDNrl4p3d9j2UXPwENrMAp0EgaRQuJCBGFvXdYoJTmI,1641
-tmnt-0.7.57.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
-tmnt-0.7.57.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
-tmnt-0.7.57.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
-tmnt-0.7.57.dist-info/RECORD,,
+tmnt-0.7.58.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.58.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.58.dist-info/METADATA,sha256=drdqhfVdpDs5LD_FMAMZjPRWw_TnNqFlGsh0QGtm8QE,1663
+tmnt-0.7.58.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+tmnt-0.7.58.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.58.dist-info/RECORD,,
tmnt-0.7.57.dist-info/WHEEL → tmnt-0.7.58.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (76.0.0)
+Generator: setuptools (78.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any