tmnt 0.7.54b20240817__py3-none-any.whl → 0.7.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/data_loading.py +5 -4
- tmnt/estimator.py +6 -29
- tmnt/eval_npmi.py +19 -0
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/METADATA +1 -1
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/RECORD +9 -9
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/WHEEL +1 -1
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/LICENSE +0 -0
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/NOTICE +0 -0
- {tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/top_level.txt +0 -0
tmnt/data_loading.py
CHANGED
@@ -41,13 +41,14 @@ llm_catalog = {
     'allenai/scibert_scivocab_uncased': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'johngiorgi/declutr-sci-base': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
     'BAAI/bge-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
-    'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
-
+    'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained),
+    'Alibaba-NLP/gte-base-en-v1.5': (AutoTokenizer.from_pretrained, AutoModel.from_pretrained)
+    ## add more model options here ...
 }
 
 def get_llm(model_name):
     tok_fn, model_fn = llm_catalog.get(model_name, ((AutoTokenizer.from_pretrained, AutoModel.from_pretrained)))
-    return tok_fn(model_name), model_fn(model_name)
+    return tok_fn(model_name), model_fn(model_name, trust_remote_code=True)
 
 def get_llm_tokenizer(model_name):
     tok_fn, model_fn = llm_catalog.get(model_name, ((AutoTokenizer.from_pretrained, AutoModel.from_pretrained)))
@@ -55,7 +56,7 @@ def get_llm_tokenizer(model_name):
 
 def get_llm_model(model_name):
     tok_fn, model_fn = llm_catalog.get(model_name, ((AutoTokenizer.from_pretrained, AutoModel.from_pretrained)))
-    return model_fn(model_name)
+    return model_fn(model_name, trust_remote_code=True)
 
 def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
     label_pipeline = lambda x: label_map.get(x, 0)
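Worth noting in this change: get_llm and get_llm_model now pass trust_remote_code=True to the model loader, and the catalog gains 'Alibaba-NLP/gte-base-en-v1.5', a checkpoint that ships its own modeling code on the Hugging Face Hub and therefore needs that flag. A minimal standalone sketch of the equivalent call (assuming the transformers package is installed; this is not the packaged helper itself):

    from transformers import AutoTokenizer, AutoModel

    # Equivalent of the updated get_llm('Alibaba-NLP/gte-base-en-v1.5') path:
    # the tokenizer loads as usual, while the model load opts in to the
    # checkpoint's custom code via trust_remote_code=True.
    model_name = 'Alibaba-NLP/gte-base-en-v1.5'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    print(type(model).__name__)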
tmnt/estimator.py
CHANGED
@@ -285,16 +285,11 @@ class BaseBowEstimator(BaseEstimator):
             logging.error("File {} does not appear to be a valid vocabulary file".format(vocabulary))
             raise Exception("Invalid Json Configuration File")
         vocabulary = torchtext.vocab.vocab(voc_js)
-
-        if
-
-
-
-        emb_size = config['embedding'].get('size')
-        if not emb_size:
-            emb_size = config['derived_info'].get('embedding_size')
-        if not emb_size:
-            raise Exception("Embedding size must be provided as the 'size' attribute of 'embedding' or as 'derived_info.embedding_size'")
+        emb_size = config['embedding'].get('size')
+        if not emb_size:
+            emb_size = config['derived_info'].get('embedding_size')
+        if not emb_size:
+            raise Exception("Embedding size must be provided as the 'size' attribute of 'embedding' or as 'derived_info.embedding_size'")
         gamma = config.get('gamma', 1.0)
         multilabel = config.get('multilabel', False)
         lr = config['lr']
@@ -781,12 +776,6 @@ class BowMetricEstimator(BowEstimator):
     def _get_model(self, bow_size=-1):
         if self.embedding_source != 'random':
             e_type, e_name = tuple(self.embedding_source.split(':'))
-            #pt_embedding = nlp.embedding.create(e_type, source=e_name)
-            #self.vocabulary.set_embedding(pt_embedding)
-            #emb_size = len(self.vocabulary.embedding.idx_to_vec[0])
-            #for word in self.vocabulary.embedding._idx_to_token:
-            #    if (self.vocabulary.embedding[word] == mx.nd.zeros(emb_size)).sum() == emb_size:
-            #        self.vocabulary.embedding[word] = mx.nd.random.normal(0, 0.1, emb_size)
         else:
             emb_size = self.embedding_size
         model = \
@@ -1030,7 +1019,6 @@ class SeqBowEstimator(BaseEstimator):
         tr_bow_counts = self._get_bow_wd_counts(train_data)
         model.initialize_bias_terms(tr_bow_counts)
         if self.npmi_matrix is not None:
-            print("****** INITIALIZING NPMI LOSS FUNCTION *******")
             model.initialize_npmi_loss(self.npmi_matrix)
         return model
 
@@ -1057,7 +1045,6 @@ class SeqBowEstimator(BaseEstimator):
         else:
             config['latent_distribution'] = {'dist_type':'gaussian'}
         config['epochs'] = self.epochs
-        #config['embedding_source'] = self.embedding_source
         config['gamma'] = self.gamma
         config['warmup_ratio'] = self.warmup_ratio
         config['llm_model_name'] = self.llm_model_name
@@ -1091,9 +1078,6 @@ class SeqBowEstimator(BaseEstimator):
                       log_interval, epoch_id, learning_rate):
         """Generate and print out the log message for training. """
         if self.has_classifier:
-            #metric_nm, metric_val = self.metric.compute()
-            #if not isinstance(metric_nm, list):
-            #    metric_nm, metric_val = [metric_nm], [metric_val]
             metric_nm = "AUPRC"
             try:
                 metric_val = self.metric.compute()
@@ -1126,7 +1110,6 @@ class SeqBowEstimator(BaseEstimator):
         rows = 0
         for i, data in enumerate(dataloader):
             seqs, = data
-            #bow_batch = list(seqs[3].squeeze(axis=1))
             bow_batch = list(seqs[3])
             rows += len(bow_batch)
             if i >= max_rows:
@@ -1170,10 +1153,7 @@ class SeqBowEstimator(BaseEstimator):
                 label_ls = label_ls.mean()
                 total_ls = (self.gamma * label_ls) + elbo_ls.mean()
                 if not self.multilabel:
-                    #label_ind = label.argmax(dim=0)
-                    #self.metric.update([out], [label_ind])
                     self.metric.update(torch.tensor(out), torch.tensor(label))
-                    #self.metric.update(torch.Tensor([out]), torch.Tensor([label_ind]))
                 else:
                     self.metric.update([out], [label])
             else:
@@ -1214,7 +1194,6 @@ class SeqBowEstimator(BaseEstimator):
         joint_loader = PairedDataLoader(train_data, aux_data)
         num_train_steps = len(joint_loader) * self.epochs
 
-        ## The following from HuggingFace trainer.py lines 1047 to 1063
         decay_parameters = get_parameter_names(model.llm, ALL_LAYERNORM_LAYERS)
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         non_llm_parameters = [name for name,_ in model.named_parameters() if not name.startswith("llm")]
@@ -1288,10 +1267,8 @@ class SeqBowEstimator(BaseEstimator):
                 if aux_batch is not None:
                     update_loss_details(total_ls_2, elbo_ls_2, red_ls_2, None)
 
-                #debug
-
                 if not accumulate or (batch_id + 1) % accumulate == 0:
-
+                    torch.nn.utils.clip_grad.clip_grad_value_(model.llm.parameters(), 1.0)
                     optimizer.step()
                     dec_optimizer.step()
                     lr_scheduler.step()
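Two functional changes in this file sit alongside the comment cleanup: BaseBowEstimator now resolves the embedding size by falling back from config['embedding']['size'] to config['derived_info']['embedding_size'], and SeqBowEstimator clips the LLM parameters' gradients element-wise before each optimizer step. A self-contained sketch of that clipping call is below; the small Linear module standing in for model.llm is hypothetical, used only to make the example runnable.

    import torch

    # Hypothetical stand-in for model.llm; only the clipping call mirrors the diff.
    llm = torch.nn.Linear(8, 8)
    loss = llm(torch.randn(4, 8)).sum()
    loss.backward()

    # Clamp every gradient element to [-1.0, 1.0], as now done at the
    # gradient-accumulation boundary in SeqBowEstimator's training loop.
    torch.nn.utils.clip_grad.clip_grad_value_(llm.parameters(), 1.0)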
tmnt/eval_npmi.py
CHANGED
@@ -115,6 +115,25 @@ class EvaluateNPMI(object):
                     npmi = (log10(n_docs) + log10(bigram_cnt) - log10(unigram_1) - log10(unigram_2)) / (log10(n_docs) - log10(bigram_cnt) + 1e-4)
                     npmi_matrix[w1, w2] = npmi
         return npmi_matrix
+
+class EvaluateNPMIUmass(object):
+
+    def __init__(self, npmi_matrix: np.array, vectorizer: TMNTVectorizer):
+        self.vectorizer = vectorizer
+        self.npmi_matrix = npmi_matrix # by convention this will be lower-triangular
+        dim = npmi_matrix.shape[0]
+        for mc in range(self.npmi_matrix.shape[0]):
+            for i in range(mc+1,dim):
+                self.npmi_matrix[mc,i] = self.npmi_matrix[i,mc]
+
+    def evaluate_topics(self, topic_ids):
+        npmi_score = 0.0
+        total_size = len(topic_ids) * len(topic_ids[0])
+        for topic in topic_ids:
+            for (w1, w2) in combinations(topic):
+                npmi_score += self.npmi_matrix[w1, w2]
+        return npmi_score / total_size
+
 
 
 class FullNPMI(object):
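One thing to watch in the new EvaluateNPMIUmass.evaluate_topics: combinations(topic) is called with a single argument, and if that name resolves to itertools.combinations the pair size is required. A standalone sketch of the same pairwise scoring under that assumption follows; the helper name and the toy inputs are illustrative, not part of the package.

    import numpy as np
    from itertools import combinations

    def average_pairwise_npmi(npmi_matrix: np.ndarray, topic_ids) -> float:
        # Mirror the class constructor: reflect the lower triangle so lookups
        # work regardless of term-id order within a topic.
        full = np.tril(npmi_matrix) + np.tril(npmi_matrix, -1).T
        score = 0.0
        total_size = len(topic_ids) * len(topic_ids[0])
        for topic in topic_ids:
            for w1, w2 in combinations(topic, 2):  # pair size must be given
                score += full[w1, w2]
        return score / total_size

    # Toy example: two topics of three term ids each over a 5x5 NPMI matrix.
    print(average_pairwise_npmi(np.eye(5), [[0, 1, 2], [2, 3, 4]]))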
{tmnt-0.7.54b20240817.dist-info → tmnt-0.7.56.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
 tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
-tmnt/data_loading.py,sha256=
+tmnt/data_loading.py,sha256=vsAMyHGi3fuOFDmqo_zenNKOtVQiuqMHA-iPYWYpGKE,18873
 tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
-tmnt/estimator.py,sha256=
-tmnt/eval_npmi.py,sha256=
+tmnt/estimator.py,sha256=htQ_JeUedEYWLPIBDbDhEL5deWtHiVNRKQN1528SybY,67751
+tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
 tmnt/inference.py,sha256=da8qAnjTDTuWQfPEOQewOfgikqE00XT1xGMiO2mckI4,15679
 tmnt/modeling.py,sha256=O1V7ppU7J6pvESTvdEoV9BXbEF4Z-J1OHnRtszuagaA,29956
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
 tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
+tmnt-0.7.56.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.56.dist-info/METADATA,sha256=jk7-JlrqxLTACr0LsMoLGXT0nq0VVQIkWFoFNqYlEPE,1436
+tmnt-0.7.56.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.56.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+tmnt-0.7.56.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.56.dist-info/RECORD,,
File without changes
File without changes
File without changes