tmnt 0.7.0b20230910__py3-none-any.whl → 0.7.0b20230912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/data_loading.py +12 -12
- tmnt/distribution.py +1 -1
- tmnt/estimator.py +3 -2
- tmnt/modeling.py +3 -3
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/METADATA +1 -1
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/RECORD +10 -10
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/LICENSE +0 -0
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/NOTICE +0 -0
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/WHEEL +0 -0
- {tmnt-0.7.0b20230910.dist-info → tmnt-0.7.0b20230912.dist-info}/top_level.txt +0 -0
tmnt/data_loading.py
CHANGED
@@ -92,7 +92,7 @@ def get_llm_paired_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_ma
|
|
92
92
|
|
93
93
|
class StratifiedPairedLLMLoader():
|
94
94
|
|
95
|
-
def __init__(self, data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b, device='cpu'):
|
95
|
+
def __init__(self, data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b, num_batches=0, device='cpu'):
|
96
96
|
self.data_a = data_a
|
97
97
|
self.data_b = data_b
|
98
98
|
self.bow_vectorizer = bow_vectorizer
|
@@ -102,7 +102,7 @@ class StratifiedPairedLLMLoader():
|
|
102
102
|
self.max_len_a = max_len_a
|
103
103
|
self.max_len_b = max_len_b
|
104
104
|
self.device = device
|
105
|
-
self.num_batches = max(len(data_a), len(data_b)) // batch_size
|
105
|
+
self.num_batches = num_batches or max(len(data_a), len(data_b)) // batch_size
|
106
106
|
self.stratified_sampler = StratifiedDualBatchSampler(np.array([label_map[l] for (l,_) in data_a]),
|
107
107
|
np.array([label_map[l] for (l,_) in data_b]),
|
108
108
|
batch_size,
|
@@ -419,22 +419,22 @@ class StratifiedDualBatchSampler:
|
|
419
419
|
self.shuffle = shuffle
|
420
420
|
self.batch_size = batch_size
|
421
421
|
self.num_batches = num_batches
|
422
|
-
counts_a = Counter(y_a)
|
423
|
-
counts_b = Counter(y_b)
|
422
|
+
self.counts_a = Counter(y_a)
|
423
|
+
self.counts_b = Counter(y_b)
|
424
424
|
self.class_weights_a = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
|
425
425
|
self.class_weights_b = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
|
426
|
-
for k in counts_a:
|
427
|
-
self.class_weights_a[k] = counts_a[k] / len(y_a)
|
428
|
-
for k in counts_b:
|
429
|
-
self.class_weights_b[k] = counts_b[k] / len(y_b)
|
426
|
+
for k in self.counts_a:
|
427
|
+
self.class_weights_a[k] = self.counts_a[k] / len(y_a)
|
428
|
+
for k in self.counts_b:
|
429
|
+
self.class_weights_b[k] = self.counts_b[k] / len(y_b)
|
430
430
|
self.class_indices_a = [0] * (max(np.max(y_a), np.max(y_b)) + 1)
|
431
431
|
self.class_indices_b = [0] * (max(np.max(y_b), np.max(y_a)) + 1)
|
432
432
|
for i in range(len(self.class_indices_a)):
|
433
433
|
self.class_indices_a[i] = list(np.where(y_a == i)[0])
|
434
434
|
for i in range(len(self.class_indices_b)):
|
435
435
|
self.class_indices_b[i] = list(np.where(y_b == i)[0])
|
436
|
-
self.a_only = counts_a.keys() - counts_b.keys()
|
437
|
-
self.b_only = counts_b.keys() - counts_a.keys()
|
436
|
+
self.a_only = self.counts_a.keys() - self.counts_b.keys()
|
437
|
+
self.b_only = self.counts_b.keys() - self.counts_a.keys()
|
438
438
|
self.use_with_replacement = (self.batch_size > len(self.class_weights_a))
|
439
439
|
|
440
440
|
def _pop_leave_last(self, li):
|
@@ -451,14 +451,14 @@ class StratifiedDualBatchSampler:
|
|
451
451
|
for i in range(self.num_batches):
|
452
452
|
if i % 2 == 0:
|
453
453
|
classes_a = list(WeightedRandomSampler(self.class_weights_a, self.batch_size, replacement=self.use_with_replacement))
|
454
|
-
b_list = list(self.
|
454
|
+
b_list = list(self.counts_b)
|
455
455
|
random.shuffle(b_list)
|
456
456
|
classes_b = [ self._pop_leave_last(b_list) if a in self.a_only else a for a in classes_a]
|
457
457
|
batch_indices_a = [ self.class_indices_a[c][next(samplers_a[c])] for c in classes_a]
|
458
458
|
batch_indices_b = [ self.class_indices_b[c][next(samplers_b[c])] for c in classes_b]
|
459
459
|
else:
|
460
460
|
classes_b = list(WeightedRandomSampler(self.class_weights_b, self.batch_size, replacement=self.use_with_replacement))
|
461
|
-
a_list = list(self.
|
461
|
+
a_list = list(self.counts_a)
|
462
462
|
random.shuffle(a_list)
|
463
463
|
classes_a = [ self._pop_leave_last(a_list) if b in self.b_only else b for b in classes_b]
|
464
464
|
batch_indices_a = [ self.class_indices_a[c][next(samplers_a[c])] for c in classes_a]
|
tmnt/distribution.py
CHANGED
@@ -168,7 +168,7 @@ class LogisticGaussianDistribution(BaseDistribution):
|
|
168
168
|
class VonMisesDistribution(BaseDistribution):
|
169
169
|
|
170
170
|
def __init__(self, enc_size, n_latent, kappa=100.0, dr=0.1, device='cpu'):
|
171
|
-
super(VonMisesDistribution, self).__init__(enc_size, n_latent, device)
|
171
|
+
super(VonMisesDistribution, self).__init__(enc_size, n_latent, device, on_simplex=False)
|
172
172
|
self.device = device
|
173
173
|
self.kappa = kappa
|
174
174
|
self.kld_v = torch.tensor(VonMisesDistribution._vmf_kld(self.kappa, self.n_latent), device=device)
|
tmnt/estimator.py
CHANGED
@@ -1231,6 +1231,7 @@ class SeqBowEstimator(BaseEstimator):
|
|
1231
1231
|
llm_base_model = get_llm_model(self.llm_model_name).to(self.device)
|
1232
1232
|
model = SeqBowVED(llm_base_model, self.latent_distribution, num_classes=self.n_labels, device=self.device,
|
1233
1233
|
vocab_size = len(self.vocabulary), use_pooling = (self.llm_model_name.startswith("sentence-transformers")),
|
1234
|
+
entropy_loss_coef=self.entropy_loss_coef,
|
1234
1235
|
dropout=self.classifier_dropout)
|
1235
1236
|
return model
|
1236
1237
|
|
@@ -1583,10 +1584,10 @@ class SeqBowEstimator(BaseEstimator):
|
|
1583
1584
|
|
1584
1585
|
class SeqBowMetricEstimator(SeqBowEstimator):
|
1585
1586
|
|
1586
|
-
def __init__(self, *args, sdml_smoothing_factor=0.3, non_scoring_index=-1, **kwargs):
|
1587
|
+
def __init__(self, *args, sdml_smoothing_factor=0.3, metric_loss_temp=0.1, non_scoring_index=-1, **kwargs):
|
1587
1588
|
super(SeqBowMetricEstimator, self).__init__(*args, **kwargs)
|
1588
1589
|
#self.loss_function = GeneralizedSDMLLoss(smoothing_parameter=sdml_smoothing_factor, x2_downweight_idx=non_scoring_index)
|
1589
|
-
self.loss_function = MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor)
|
1590
|
+
self.loss_function = MultiNegativeCrossEntropyLoss(smoothing_parameter=sdml_smoothing_factor, metric_loss_temp=metric_loss_temp)
|
1590
1591
|
self.non_scoring_index = non_scoring_index ## if >=0 this will avoid considering this label index in evaluation
|
1591
1592
|
|
1592
1593
|
|
tmnt/modeling.py
CHANGED
@@ -506,7 +506,7 @@ class BaseSeqBowVED(BaseVAE):
|
|
506
506
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
507
507
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
508
508
|
else:
|
509
|
-
model_output.last_hidden_state[:,0,:]
|
509
|
+
return model_output.last_hidden_state[:,0,:]
|
510
510
|
|
511
511
|
def get_ordered_terms(self):
|
512
512
|
"""
|
@@ -551,8 +551,8 @@ class SeqBowVED(BaseSeqBowVED):
|
|
551
551
|
super(SeqBowVED, self).__init__(*args, **kwargs)
|
552
552
|
if self.has_classifier:
|
553
553
|
self.classifier = torch.nn.Sequential()
|
554
|
-
self.classifier.add_module("dr", nn.Dropout(self.dropout))
|
555
|
-
self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes))
|
554
|
+
self.classifier.add_module("dr", nn.Dropout(self.dropout).to(self.device))
|
555
|
+
self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes).to(self.device))
|
556
556
|
|
557
557
|
def forward(self, input_ids, attention_mask, bow=None): # pylint: disable=arguments-differ
|
558
558
|
llm_output = self.llm(input_ids, attention_mask)
|
@@ -2,12 +2,12 @@ tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
|
|
2
2
|
tmnt/bert_handling.py,sha256=4l78pzLjK0rbsGa3YxCsfVEndJPzaXTaj_928ZPZfSk,24677
|
3
3
|
tmnt/common_params.py,sha256=uNWs1UuaTx6xlxbS5LailnuXhuIKCg8kCJqxes-kAGY,2547
|
4
4
|
tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
|
5
|
-
tmnt/data_loading.py,sha256=
|
6
|
-
tmnt/distribution.py,sha256
|
7
|
-
tmnt/estimator.py,sha256=
|
5
|
+
tmnt/data_loading.py,sha256=Fnn3Pdrw16e6IR_QEPusiUfSCrHlk-3ddKeyzQW_5JE,18569
|
6
|
+
tmnt/distribution.py,sha256=JrJe2HaF2uub0S8RxLAjSykg_AF6atvgUhNWbbaxaMo,8382
|
7
|
+
tmnt/estimator.py,sha256=lMcuVvPIBiI1ChUpllCFR6ygWnclpcOcYA-ht5r4W3s,78680
|
8
8
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
9
9
|
tmnt/inference.py,sha256=Hc0PRmUBLr9YbfqAGyw6-1BQqiwUUEtS8ehsWkr7QJk,18399
|
10
|
-
tmnt/modeling.py,sha256=
|
10
|
+
tmnt/modeling.py,sha256=Zw6F3PZcaJsUZVuyk-sQIMNWMaqcjQH5cetZTmcqI7g,32940
|
11
11
|
tmnt/selector.py,sha256=DWJlbdWKNxJmyLI9IRxCa3FDmaCilxTDzNfIf5mpBqc,9578
|
12
12
|
tmnt/trainer.py,sha256=xaJtU_vHAPbos9q86NNdnwz7kpUF5BxGyTWbDG_NBA0,25802
|
13
13
|
tmnt/classifier/__init__.py,sha256=1gLyJjCMHmnWdf-J4gnRs4uhbebtvCs9RgnZze1HTXY,67
|
@@ -30,9 +30,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
30
30
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
31
31
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
32
32
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
33
|
-
tmnt-0.7.
|
34
|
-
tmnt-0.7.
|
35
|
-
tmnt-0.7.
|
36
|
-
tmnt-0.7.
|
37
|
-
tmnt-0.7.
|
38
|
-
tmnt-0.7.
|
33
|
+
tmnt-0.7.0b20230912.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
34
|
+
tmnt-0.7.0b20230912.dist-info/METADATA,sha256=vx4UrdxOGQ2BPb9o-5mipGun99SVmyw0DOL4aS58gOU,997
|
35
|
+
tmnt-0.7.0b20230912.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
36
|
+
tmnt-0.7.0b20230912.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
|
37
|
+
tmnt-0.7.0b20230912.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
38
|
+
tmnt-0.7.0b20230912.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|