tmnt 0.7.60__py3-none-any.whl → 0.7.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/__init__.py CHANGED
@@ -1,8 +1,9 @@
1
1
  # coding: utf-8
2
2
 
3
- from .distribution import *
3
+
4
4
  from .preprocess import *
5
5
  from .sparse import *
6
6
  from .utils import *
7
+ from .distribution import *
7
8
 
8
- __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__ + sparse.__all__
9
+ __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__ # + sparse.__all__
tmnt/estimator.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
15
  import scipy.sparse as sp
16
16
  import json
17
17
 
18
- from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
18
+ from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score
19
19
  from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
20
20
  from tmnt.modeling import BowVAEModel, SeqBowVED, BaseVAE
21
21
  from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
@@ -29,18 +29,15 @@ from torcheval.metrics import MultilabelAUPRC, MulticlassAUPRC
29
29
  ## huggingface specifics
30
30
  from transformers.trainer_pt_utils import get_parameter_names
31
31
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
32
- from transformers.optimization import AdamW, get_scheduler
32
+ from transformers.optimization import get_scheduler
33
33
 
34
34
  ## model selection
35
35
  import optuna
36
36
 
37
- from itertools import cycle
38
37
  import pickle
39
38
  from typing import List, Tuple, Dict, Optional, Union, NoReturn
40
39
 
41
40
  import torch
42
- from torch.utils.data import Dataset, DataLoader
43
- from tqdm import tqdm
44
41
 
45
42
  MAX_DESIGN_MATRIX = 250000000
46
43
 
tmnt/sparse/config.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import torch
2
2
 
3
+ __all__ = ['get_default_cfg']
4
+
3
5
  def get_default_cfg():
4
6
  default_cfg = {
5
7
  "seed": 49,
tmnt/sparse/estimator.py CHANGED
@@ -8,6 +8,8 @@ import io, json
8
8
  from tmnt.sparse.modeling import BaseAutoencoder
9
9
  from typing import List
10
10
 
11
+ __all__ = ['ActivationsStore', 'build_activation_store', 'build_activation_store_batching', 'train_sparse_encoder_decoder']
12
+
11
13
  class ActivationsStore:
12
14
  def __init__(
13
15
  self,
tmnt/sparse/inference.py CHANGED
@@ -11,6 +11,7 @@ from datasets.arrow_writer import ArrowWriter
11
11
  from tmnt.inference import SeqVEDInferencer
12
12
  import io, json
13
13
 
14
+ __all__ = ['batch_process_to_arrow']
14
15
 
15
16
  def csr_to_indices_data(csr_mat):
16
17
  return [ (csr_mat.getrow(ri).indices, csr_mat.getrow(ri).data) for ri in range(csr_mat.shape[0]) ]
tmnt/sparse/modeling.py CHANGED
@@ -3,6 +3,9 @@ import torch.nn as nn
3
3
  import torch.nn.functional as F
4
4
  import torch.autograd as autograd
5
5
 
6
+
7
+ __all__ = ['BatchTopKEncoder', 'BatchTopKSAE', 'TopKEncoder', 'TopKSAE', 'VanillaEncoder', 'VanillaSAE', 'JumpReLUEncoder', 'JumpReLUSAE']
8
+
6
9
  class BaseEncoder(nn.Module):
7
10
 
8
11
  def __init__(self, cfg):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tmnt
3
- Version: 0.7.60
3
+ Version: 0.7.61
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -48,7 +48,7 @@ Dynamic: summary
48
48
  The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
49
49
  topic models as neural network-based variational auto-encoders.
50
50
 
51
- Current stable version is: 0.7.60
51
+ Current stable version is: 0.7.61
52
52
 
53
53
  Documentation can be found here: https://tmnt.readthedocs.io/en/stable/
54
54
 
@@ -1,8 +1,8 @@
1
- tmnt/__init__.py,sha256=s7YqLj32HKhIYO1QbD0zms8rDlTrleJM8LRjKt8bDPk,200
1
+ tmnt/__init__.py,sha256=3W6sRKacvIMQuxGZpscjU52B4sDcasEA1RCXJVHaK-Q,203
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
3
  tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
4
4
  tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
5
- tmnt/estimator.py,sha256=KnnvSNXm6cRL0GwDrGdgqqPX5ZubpCQ0WqcSXJDkUU4,68072
5
+ tmnt/estimator.py,sha256=SJW9koDHb_lTwmqXm4Ilgga7MMMmRGLd3ylEFYu3bZg,67933
6
6
  tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
7
7
  tmnt/inference.py,sha256=Iwc2_w7QrS1epiVEm_Ewx5sYFNNMDfvhMJETOgJqm0E,15783
8
8
  tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
@@ -10,10 +10,10 @@ tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,1
10
10
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
11
11
  tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
12
12
  tmnt/sparse/__init__.py,sha256=BEhOm_o0UrVUKTG3rSiBJzE7qQQL9HRSZ1MHCA2GJu8,249
13
- tmnt/sparse/config.py,sha256=gfJ1BAP3zzMKKUJExrs8D0hHB7XOVvqXPpmx0ECAPIE,796
14
- tmnt/sparse/estimator.py,sha256=SOEPkQo7T2RuBJLYRAk65MhoYjgYtNCAzlvhqq10c5E,4596
15
- tmnt/sparse/inference.py,sha256=etOuXTxh8bKc7EoohZLYYufFAD51aEpq8mGj1vjULbg,2813
16
- tmnt/sparse/modeling.py,sha256=IejHbRXyj5WsEthkVbp3vYF1wTAAsihbneaPuCgAfGA,13612
13
+ tmnt/sparse/config.py,sha256=DzqzmclxbUmmBZHS7poMicsvgOZCMXTXWEY5bTpPrRI,827
14
+ tmnt/sparse/estimator.py,sha256=PSqIopCZG7WO5MXYcN-CAs7R7M36jZXjQeUp0I5kG9Q,4721
15
+ tmnt/sparse/inference.py,sha256=KiSUjo_lO14qkUzzo9HuOWZ5S5oJRzoSsSa2dOg54UY,2850
16
+ tmnt/sparse/modeling.py,sha256=hy2nkI2tq3Bpmj7OfXr_zwHRajWBwsqz7T3wHTuEbyE,13753
17
17
  tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
18
18
  tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
19
19
  tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -23,9 +23,9 @@ tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,12
23
23
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
24
24
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
25
25
  tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
26
- tmnt-0.7.60.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
27
- tmnt-0.7.60.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
28
- tmnt-0.7.60.dist-info/METADATA,sha256=DbmpuasEyoW6FNHmu-YfNsgq5J1lV0MjWPe04CCq8Fk,1663
29
- tmnt-0.7.60.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
30
- tmnt-0.7.60.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
31
- tmnt-0.7.60.dist-info/RECORD,,
26
+ tmnt-0.7.61.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
27
+ tmnt-0.7.61.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
28
+ tmnt-0.7.61.dist-info/METADATA,sha256=hyptC3YXersP-7oR89NgKU4UksTrvsbcC5Ml38WSYG0,1663
29
+ tmnt-0.7.61.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
30
+ tmnt-0.7.61.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
31
+ tmnt-0.7.61.dist-info/RECORD,,
File without changes