libmultilabel 0.6.2.tar.gz → 0.7.0.tar.gz
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/PKG-INFO +2 -2
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/README.md +1 -1
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/linear.py +4 -1
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/tree.py +3 -1
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/utils.py +1 -0
- libmultilabel-0.7.0/libmultilabel/nn/attentionxml.py +800 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/model.py +6 -6
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/__init__.py +1 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/labelwise_attention_networks.py +69 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/modules.py +32 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/nn_utils.py +0 -1
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/PKG-INFO +2 -2
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/SOURCES.txt +1 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/setup.cfg +2 -2
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/LICENSE +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/__init__.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/common_utils.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/__init__.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/data_utils.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/metrics.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/preprocessor.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/logging.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/__init__.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/data_utils.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/metrics.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/bert.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/bert_attention.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/caml.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/kim_cnn.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/xml_cnn.py +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/dependency_links.txt +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/requires.txt +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/top_level.txt +0 -0
- {libmultilabel-0.6.2 → libmultilabel-0.7.0}/pyproject.toml +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,7 +1,7 @@
 Metadata-Version: 2.1
 Name: libmultilabel
-Version: 0.6.2
-Summary: A library for multi-label
+Version: 0.7.0
+Summary: A library for multi-class and multi-label classification
 Home-page: https://github.com/ASUS-AICS/LibMultiLabel
 Author: LibMultiLabel Team
 License: MIT License
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# LibMultiLabel — a Library for Multi-class and Multi-label
+# LibMultiLabel — a Library for Multi-class and Multi-label Classification
 
 LibMultiLabel is a library for binary, multi-class, and multi-label classification. It has the following functionalities
 
--- a/libmultilabel/linear/linear.py
+++ b/libmultilabel/linear/linear.py
@@ -17,10 +17,13 @@ __all__ = [
     "predict_values",
     "get_topk_labels",
     "get_positive_labels",
+    "FlatModel",
 ]
 
 
 class FlatModel:
+    """A model returned from a training function."""
+
     def __init__(
         self,
         name: str,
@@ -619,7 +622,7 @@ def train_binary_and_multiclass(
 
 
 def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
-    """Calculates the decision values associated with x.
+    """Calculates the decision values associated with x, equivalent to model.predict_values(x).
 
     Args:
         model: A model returned from a training function.
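The clarified docstring makes the module-level helper and the method form interchangeable. A minimal runnable sketch of that equivalence on synthetic data (the data and the choice of train_1vsrest are illustrative, not taken from this diff; it assumes the 0.7.0 behavior described in the docstring):

    import numpy as np
    import scipy.sparse as sparse
    import libmultilabel.linear as linear

    # Tiny synthetic problem: 4 instances, 5 features, 3 labels.
    rng = np.random.default_rng(0)
    x = sparse.csr_matrix(rng.random((4, 5)))
    y = sparse.csr_matrix((rng.random((4, 3)) > 0.5).astype(float))

    model = linear.train_1vsrest(y, x)
    # Module-level helper vs. method form: same decision values.
    assert np.array_equal(linear.predict_values(model, x), model.predict_values(x))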
--- a/libmultilabel/linear/tree.py
+++ b/libmultilabel/linear/tree.py
@@ -10,7 +10,7 @@ from tqdm import tqdm
 
 from . import linear
 
-__all__ = ["train_tree"]
+__all__ = ["train_tree", "TreeModel"]
 
 
 class Node:
@@ -38,6 +38,8 @@ class Node:
 
 
 class TreeModel:
+    """A model returned from train_tree."""
+
     def __init__(
         self,
         root: Node,
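Exporting TreeModel alongside train_tree makes the tree trainer's return type part of the public API. A small sketch, assuming libmultilabel.linear re-exports tree.py's public names (synthetic data as above, sized so the clustering step has labels to work with):

    import numpy as np
    import scipy.sparse as sparse
    import libmultilabel.linear as linear

    rng = np.random.default_rng(0)
    x = sparse.csr_matrix(rng.random((50, 10)))
    y = sparse.csr_matrix((rng.random((50, 8)) > 0.7).astype(float))

    model = linear.train_tree(y, x)
    assert isinstance(model, linear.TreeModel)    # importable as of 0.7.0
    print(linear.predict_values(model, x).shape)  # (50, 8)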
--- /dev/null
+++ b/libmultilabel/nn/attentionxml.py
@@ -0,0 +1,800 @@
+from __future__ import annotations
+
+import logging
+from functools import partial
+from pathlib import Path
+from typing import Generator, Sequence, Optional
+
+import numpy as np
+import torch
+from lightning import Trainer
+from numpy import ndarray
+from scipy.sparse import csr_matrix, csc_matrix, issparse
+from scipy.special import expit
+from sklearn.preprocessing import MultiLabelBinarizer, normalize
+from torch import Tensor, is_tensor
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from .data_utils import UNK
+from ..common_utils import dump_log
+from ..linear.preprocessor import Preprocessor
+from ..nn import networks
+from ..nn.model import Model
+
+__all__ = ["PLTTrainer"]
+
+from .nn_utils import init_trainer, init_model
+
+logger = logging.getLogger(__name__)
+
+
+class PLTTrainer:
+    CHECKPOINT_NAME = "model_"
+
+    def __init__(
+        self,
+        config,
+        classes: Optional[list] = None,
+        embed_vecs: Optional[Tensor] = None,
+        word_dict: Optional[dict] = None,
+    ):
+        # The number of levels is set to 2. In other words, there will be 2 models
+        self.multiclass = config.multiclass
+        if self.multiclass:
+            raise ValueError(
+                "The label space of multi-class datasets is usually not large, so PLT training is unnecessary. "
+                "Please consider other methods. "
+                "If you have a multi-class set with numerous labels, please let us know."
+            )
+
+        # cluster
+        self.cluster_size = config.cluster_size
+        # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training
+        self.beam_width = config.beam_width
+
+        # dataset meta info
+        self.embed_vecs = embed_vecs
+        self.word_dict = word_dict
+        self.classes = classes
+        self.max_seq_length = config.max_seq_length
+        self.num_classes = len(classes)
+
+        # multilabel binarizer fitted to the datasets
+        self.binarizer = None
+
+        # cluster meta info
+        self.cluster_size = config.cluster_size
+
+        # network parameters
+        self.network_config = config.network_config
+        self.init_weight = "xavier_uniform"  # AttentionXML-specific setting
+        self.loss_function = config.loss_function
+
+        # optimizer parameters
+        self.optimizer = config.optimizer
+        self.learning_rate = config.learning_rate
+        self.momentum = config.momentum
+        self.weight_decay = config.weight_decay
+        # learning rate scheduler
+        self.lr_scheduler = config.lr_scheduler
+        self.scheduler_config = config.scheduler_config
+
+        # Trainer parameters
+        self.use_cpu = config.cpu
+        self.accelerator = "cpu" if self.use_cpu else "gpu"
+        self.devices = 1
+        self.num_nodes = 1
+        self.epochs = config.epochs
+        self.limit_train_batches = config.limit_train_batches
+        self.limit_val_batches = config.limit_val_batches
+        self.limit_test_batches = config.limit_test_batches
+
+        # callbacks
+        self.silent = config.silent
+        # EarlyStopping
+        self.early_stopping_metric = config.early_stopping_metric
+        self.patience = config.patience
+        # ModelCheckpoint
+        self.val_metric = config.val_metric
+        self.checkpoint_dir = Path(config.checkpoint_dir)
+
+        self.metrics = config.monitor_metrics
+        self.metric_threshold = config.metric_threshold
+        self.monitor_metrics = config.monitor_metrics
+
+        # dataloader parameters
+        # whether to shuffle the training dataset during the training process
+        self.shuffle = config.shuffle
+        pin_memory = True if self.accelerator == "gpu" else False
+        # training DataLoader
+        self.dataloader = partial(
+            DataLoader,
+            batch_size=config.batch_size,
+            num_workers=config.data_workers,
+            pin_memory=pin_memory,
+        )
+        # evaluation DataLoader
+        self.eval_dataloader = partial(
+            self.dataloader,
+            batch_size=config.eval_batch_size,
+        )
+
+        # predict
+        self.save_k_predictions = config.save_k_predictions
+
+        # save path
+        self.log_path = config.log_path
+        self.predict_out_path = config.predict_out_path
+        self.config = config
+
+    def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, ...]:
+        """Map labels to their corresponding clusters in CSR sparse format.
+        Notice that this function deals with SPARSE matrices.
+        Assume there are 6 labels clustered as [(0, 1), (2, 3), (4, 5)]. Here (0, 1) is the cluster with index 0 and so on.
+        Given the ground-truth labels, [0, 1, 4], the resulting clusters are [0, 2].
+
+        Args:
+            cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
+            *labels (csr_matrix): labels in CSR sparse format.
+
+        Returns:
+            Generator[csr_matrix]: resulting clusters converted from labels in CSR sparse format
+        """
+        mapping = np.empty(self.num_classes, dtype=np.uint32)
+        for idx, clusters in enumerate(cluster_mapping):
+            mapping[clusters] = idx
+
+        def _label2cluster(label: csr_matrix) -> csr_matrix:
+            row = []
+            col = []
+            data = []
+            for i in range(label.shape[0]):
+                # n includes all mapped ancestor clusters
+                n = np.unique(mapping[label.indices[label.indptr[i] : label.indptr[i + 1]]])
+                row += [i] * len(n)
+                col += n.tolist()
+                data += [1] * len(n)
+            return csr_matrix((data, (row, col)), shape=(label.shape[0], len(cluster_mapping)))
+
+        return (_label2cluster(label) for label in labels)
+
+    @staticmethod
+    def cluster2label(cluster_mapping, clusters, cluster_scores=None):
+        """Expand clusters to their corresponding labels and, if available, assign scores to each label.
+        Labels inside the same cluster have the same scores. This function is applied to predictions from model 0.
+        Notice that the behaviors of this function are different from label2cluster.
+        Also notice that this function deals with DENSE matrices.
+
+        Args:
+            cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
+            clusters (np.ndarray): predicted clusters from model 0.
+            cluster_scores (Optional: np.ndarray): predicted scores of each cluster from model 0.
+
+        Returns:
+            Generator[np.ndarray]: resulting labels expanded from clusters
+        """
+
+        labels_selected = []
+
+        if cluster_scores is not None:
+            # label_scores are corresponding scores for selected labels and
+            # shape: (len(x), cluster_size * top_k)
+            label_scores = []
+            for score, cluster in zip(cluster_scores, clusters):
+                label_scores += [np.repeat(score, [len(labels) for labels in cluster_mapping[cluster]])]
+                labels_selected += [np.concatenate(cluster_mapping[cluster])]
+            return labels_selected, label_scores
+        else:
+            labels_selected = [np.concatenate(cluster_mapping[cluster]) for cluster in clusters]
+            return labels_selected
+
+    def fit(self, datasets):
+        """Fit the model to the training dataset.
+
+        Args:
+            datasets: dict containing training, validation, and/or test datasets
+        """
+        if self.get_best_model_path(level=1).exists():
+            return
+
+        # AttentionXML-specific data preprocessing
+        train_val_dataset = datasets["train"] + datasets["val"]
+        train_val_dataset = {
+            "x": [" ".join(i["text"]) for i in train_val_dataset],
+            "y": [i["label"] for i in train_val_dataset],
+        }
+
+        # Preprocessor does tf-idf vectorization and multilabel binarization
+        # For details, see libmultilabel.linear.preprocessor.Preprocessor
+        preprocessor = Preprocessor()
+        datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes}
+        # Preprocessor requires the input dictionary to have a key named "train" and will return a new dictionary with
+        # the same key.
+        train_val_dataset_tf = preprocessor.fit_transform(datasets_temp)["train"]
+        # save binarizer for testing
+        self.binarizer = preprocessor.binarizer
+
+        train_x = self.reformat_text(datasets["train"])
+        val_x = self.reformat_text(datasets["val"])
+
+        train_y = train_val_dataset_tf["y"][: len(datasets["train"])]
+        val_y = train_val_dataset_tf["y"][len(datasets["train"]) :]
+
+        # clusters are saved to the disk so that users don't need to provide the original training data when they
+        # only want to predict
+        build_label_tree(
+            sparse_x=train_val_dataset_tf["x"],
+            sparse_y=train_val_dataset_tf["y"],
+            cluster_size=self.cluster_size,
+            output_dir=self.checkpoint_dir,
+        )
+
+        clusters = np.load(self.get_cluster_path(), allow_pickle=True)
+
+        # map each y to the parent cluster indices
+        train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y)
+
+        trainer = init_trainer(
+            self.checkpoint_dir,
+            epochs=self.epochs,
+            patience=self.patience,
+            early_stopping_metric=self.early_stopping_metric,
+            val_metric=self.val_metric,
+            silent=self.silent,
+            use_cpu=self.use_cpu,
+            limit_train_batches=self.limit_train_batches,
+            limit_val_batches=self.limit_val_batches,
+            limit_test_batches=self.limit_test_batches,
+            save_checkpoints=True,
+        )
+        trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}0"
+
+        train_dataloader = self.dataloader(PlainDataset(train_x, train_y_clustered), shuffle=self.shuffle)
+        val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered))
+
+        best_model_path = self.get_best_model_path(level=0)
+        if not best_model_path.exists():
+            model_0 = init_model(
+                model_name="AttentionXML_0",
+                network_config=self.network_config,
+                classes=clusters,
+                word_dict=self.word_dict,
+                embed_vecs=self.embed_vecs,
+                init_weight=self.init_weight,
+                log_path=self.log_path,
+                learning_rate=self.learning_rate,
+                optimizer=self.optimizer,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                lr_scheduler=self.lr_scheduler,
+                scheduler_config=self.scheduler_config,
+                val_metric=self.val_metric,
+                metric_threshold=self.metric_threshold,
+                monitor_metrics=self.monitor_metrics,
+                multiclass=self.multiclass,
+                loss_function=self.loss_function,
+                silent=self.silent,
+                save_k_predictions=self.beam_width,
+            )
+
+            logger.info(f"Training level 0. Number of clusters: {len(clusters)}")
+            trainer.fit(model_0, train_dataloader, val_dataloader)
+            logger.info("Finished training level 0")
+
+        logger.info(f"Best model loaded from {best_model_path}")
+        model_0 = Model.load_from_checkpoint(best_model_path)
+
+        logger.info(
+            f"Predicting clusters with the level-0 model. We then select {self.beam_width} clusters for each instance and "
+            f"extract labels from these clusters for level 1 training."
+        )
+        # load training and validation data and predict corresponding level 0 clusters
+        train_dataloader = self.dataloader(PlainDataset(train_x))
+        val_dataloader = self.dataloader(PlainDataset(val_x))
+
+        train_pred = trainer.predict(model_0, train_dataloader)
+        val_pred = trainer.predict(model_0, val_dataloader)
+
+        train_clusters_pred = np.vstack([i["top_k_pred"] for i in train_pred])
+        val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred]))
+        val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred])
+
+        train_clusters_selected = np.empty((len(train_x), self.beam_width), dtype=np.uint)
+        for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")):
+            # relevant clusters are positive
+            pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]])
+            # Select relevant clusters first. Then, from top-predicted clusters, sequentially include them until the
+            # cluster number reaches the beam width
+            if len(pos) <= self.beam_width:
+                selected = pos
+                for y in ys:
+                    y = y.item()
+                    if len(selected) == self.beam_width:
+                        break
+                    selected.add(y)
+            # Regard positive (true) labels as samples iff they appear in the predicted labels
+            # if the number of positive labels is more than top_k. If samples are not of length top_k,
+            # add unseen predicted labels until reaching top_k.
+            else:
+                selected = set()
+                for y in ys:
+                    y = y.item()
+                    if y in pos:
+                        selected.add(y)
+                    if len(selected) == self.beam_width:
+                        break
+                if len(selected) < self.beam_width:
+                    selected = (list(selected) + list(pos - selected))[: self.beam_width]
+            train_clusters_selected[i] = np.asarray(list(selected))
+
+        train_labels_selected = PLTTrainer.cluster2label(clusters, train_clusters_selected)
+        val_labels_pred, val_scores_pred = PLTTrainer.cluster2label(clusters, val_clusters_pred, val_scores_pred)
+        num_labels_selected = self.beam_width * max(len(c) for c in clusters)
+
+        trainer = init_trainer(
+            self.checkpoint_dir,
+            epochs=self.epochs,
+            patience=self.patience,
+            early_stopping_metric=self.val_metric,
+            val_metric=self.val_metric,
+            silent=self.silent,
+            use_cpu=self.use_cpu,
+            limit_train_batches=self.limit_train_batches,
+            limit_val_batches=self.limit_val_batches,
+            limit_test_batches=self.limit_test_batches,
+            save_checkpoints=True,
+        )
+        trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}1"
+
+        # train & val dataloaders for training
+        train_dataloader = self.dataloader(
+            PLTDataset(
+                train_x,
+                train_y,
+                num_classes=self.num_classes,
+                num_labels_selected=num_labels_selected,
+                labels_selected=train_labels_selected,
+            ),
+            shuffle=self.shuffle,
+        )
+        val_dataloader = self.dataloader(
+            PLTDataset(
+                val_x,
+                val_y,
+                num_classes=self.num_classes,
+                num_labels_selected=num_labels_selected,
+                labels_selected=val_labels_pred,
+                label_scores=val_scores_pred,
+            ),
+        )
+
+        try:
+            network = getattr(networks, "AttentionXML_1")(
+                embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config)
+            )
+        except Exception:
+            raise AttributeError("Failed to initialize AttentionXML")
+
+        model_1 = PLTModel(
+            classes=self.classes,
+            word_dict=self.word_dict,
+            network=network,
+            log_path=self.log_path,
+            learning_rate=self.learning_rate,
+            optimizer=self.optimizer,
+            momentum=self.momentum,
+            weight_decay=self.weight_decay,
+            lr_scheduler=self.lr_scheduler,
+            scheduler_config=self.scheduler_config,
+            val_metric=self.val_metric,
+            metric_threshold=self.metric_threshold,
+            monitor_metrics=self.monitor_metrics,
+            multiclass=self.multiclass,
+            loss_function=self.loss_function,
+            silent=self.silent,
+            save_k_predictions=self.save_k_predictions,
+        )
+        logger.info("Initializing model 1 with weights from level 0")
+        # For weights not initialized by the level-0 model, use xavier uniform initialization
+        torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight)
+        # As the attention layer of model 1 is different from model 0, each layer needs to be initialized separately
+        model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict())
+        model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict())
+        model_1.network.output.load_state_dict(model_0.network.output.state_dict())
+
+        del model_0
+
+        logger.info(
+            f"Training level 1. Number of labels: {self.num_classes}. "
+            f"Number of labels selected: {train_dataloader.dataset.num_labels_selected}"
+        )
+        trainer.fit(model_1, train_dataloader, val_dataloader)
+        logger.info(f"Best model loaded from {best_model_path}")
+        logger.info("Finished training level 1")
+
+    def test(self, dataset):
+        # retrieve word_dict from model_1
+        # prediction starts from level 0
+        model_0 = Model.load_from_checkpoint(
+            self.get_best_model_path(level=0),
+            save_k_predictions=self.beam_width,
+        )
+        model_1 = PLTModel.load_from_checkpoint(
+            self.get_best_model_path(level=1),
+            save_k_predictions=self.save_k_predictions,
+            metrics=self.metrics,
+        )
+        self.word_dict = model_1.word_dict
+        classes = model_1.classes
+
+        test_x = self.reformat_text(dataset)
+
+        if self.binarizer is None:
+            binarizer = MultiLabelBinarizer(classes=classes, sparse_output=True)
+            binarizer.fit(None)
+            test_y = binarizer.transform((i["label"] for i in dataset))
+        else:
+            test_y = self.binarizer.transform((i["label"] for i in dataset))
+        logger.info("Testing process started")
+        trainer = Trainer(
+            devices=1,
+            accelerator=self.accelerator,
+            logger=False,
+        )
+
+        test_dataloader = self.eval_dataloader(PlainDataset(test_x))
+
+        logger.info(f"Predicting level 0. Number of clusters: {self.beam_width}")
+        test_pred = trainer.predict(model_0, test_dataloader)
+        test_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
+        test_clusters_pred = np.vstack([i["top_k_pred"] for i in test_pred])
+
+        clusters = np.load(self.get_cluster_path(), allow_pickle=True)
+        test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(clusters, test_clusters_pred, test_scores_pred)
+        num_labels_selected = self.beam_width * max(len(c) for c in clusters)
+
+        test_dataloader = self.eval_dataloader(
+            PLTDataset(
+                test_x,
+                test_y,
+                num_classes=self.num_classes,
+                num_labels_selected=num_labels_selected,
+                labels_selected=test_labels_pred,
+                label_scores=test_scores_pred,
+            ),
+        )
+
+        logger.info("Testing level 1")
+        trainer.test(model_1, test_dataloader)
+        logger.info("Testing process finished")
+
+        if self.save_k_predictions > 0:
+            batch_predictions = trainer.predict(model_1, test_dataloader)
+            pred_labels = np.vstack([batch["top_k_pred"] for batch in batch_predictions])
+            pred_scores = np.vstack([batch["top_k_pred_scores"] for batch in batch_predictions])
+            with open(self.predict_out_path, "w") as fp:
+                for pred_label, pred_score in zip(pred_labels, pred_scores):
+                    out_str = " ".join(
+                        [f"{model_1.classes[label]}:{score:.4}" for label, score in zip(pred_label, pred_score)]
+                    )
+                    fp.write(out_str + "\n")
+            logger.info(f"Saved predictions to: {self.predict_out_path}")
+
+        dump_log(self.log_path, config=self.config)
+
+    def reformat_text(self, dataset):
+        # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
+        encoded_text = list(
+            map(
+                lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)
+                if text
+                else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                [instance["text"][: self.max_seq_length] for instance in dataset],
+            )
+        )
+        # pad the first entry to max_seq_length if necessary
+        encoded_text[0] = torch.cat(
+            (
+                encoded_text[0],
+                torch.tensor(0, dtype=torch.int64).repeat(self.max_seq_length - encoded_text[0].shape[0]),
+            )
+        )
+        encoded_text = pad_sequence(encoded_text, batch_first=True)
+        return encoded_text
+
+    def get_best_model_path(self, level: int) -> Path:
+        return self.checkpoint_dir / f"{self.CHECKPOINT_NAME}{level}{ModelCheckpoint.FILE_EXTENSION}"
+
+    def get_cluster_path(self) -> Path:
+        return self.checkpoint_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}"
+
+
+###################################### Model ######################################
+
+
+class PLTModel(Model):
+    def __init__(
+        self,
+        classes,
+        word_dict,
+        network,
+        loss_function="binary_cross_entropy_with_logits",
+        log_path=None,
+        **kwargs,
+    ):
+        super().__init__(
+            classes=classes,
+            word_dict=word_dict,
+            network=network,
+            loss_function=loss_function,
+            log_path=log_path,
+            **kwargs,
+        )
+
+    def scatter_logits(
+        self,
+        logits: Tensor,
+        labels_selected: Tensor,
+        label_scores: Tensor,
+    ) -> Tensor:
+        """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
+        the whole label space. The scores of unsampled labels are set to 0."""
+        src = torch.sigmoid(logits.detach()) * label_scores
+        # During validation/testing, many fake labels might exist in a batch for the purpose of padding.
+        # A fake label has index len(classes) and does not belong to the real label space.
+        preds = torch.zeros(
+            labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
+        )
+        preds.scatter_(dim=1, index=labels_selected, src=src)
+        # slicing removes fake labels whose index is exactly len(self.classes)
+        # afterwards, preds is restored to the real label space
+        preds = preds[:, :-1]
+        return preds
+
+    def shared_step(self, batch):
+        """Return loss and predicted logits of the network.
+
+        Args:
+            batch (dict): A batch of text and label.
+
+        Returns:
+            loss (torch.Tensor): Loss between target and predicted logits.
+            pred_logits (torch.Tensor): The predicted logits (batch_size, num_classes).
+        """
+        y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1)
+        pred_logits = self(batch)
+        loss = self.loss_function(pred_logits, y.float())
+        return loss, pred_logits
+
+    def _shared_eval_step(self, batch, batch_idx):
+        logits = self(batch)
+        logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"])
+        self.eval_metric.update(logits, batch["label"].long())
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        logits = self(batch)
+        scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.save_k_predictions)
+        # This calculation is to align with the LibMultiLabel class where logits rather than probabilities are returned
+        logits = torch.logit(scores)
+        return {
+            "top_k_pred": torch.take_along_dim(batch["labels_selected"], labels, dim=1).numpy(force=True),
+            "top_k_pred_scores": logits.numpy(force=True),
+        }
+
+
+###################################### Dataset ######################################
+
+
+class PlainDataset(Dataset):
+    """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label datasets.
+    This class is necessary because it can process labels in sparse format, while TextDataset
+    cannot.
+    Moreover, TextDataset implements multilabel binarization in a mandatory way. Nevertheless, AttentionXML already does
+    this while generating clusters. There is no need to do multilabel binarization again.
+
+    Args:
+        x (list | ndarray | Tensor): texts.
+        y (Optional: csr_matrix | ndarray | Tensor): labels.
+    """
+
+    def __init__(self, x, y=None):
+        if y is not None:
+            assert len(x) == y.shape[0], "Sizes mismatch between texts and labels"
+        self.x = x
+        self.y = y
+
+    def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]:
+        item = {"text": self.x[idx]}
+
+        # train/val/test
+        if self.y is not None:
+            if issparse(self.y):
+                y = self.y[idx].toarray().squeeze(0).astype(np.int32)
+            elif isinstance(self.y, ndarray):
+                y = self.y[idx].astype(np.int32)
+            elif is_tensor(self.y):
+                y = self.y[idx].int()
+            else:
+                raise TypeError(
+                    "The type of y should be one of scipy.csr_matrix, numpy.ndarray, and torch.Tensor. "
+                    f"But got {type(self.y)} instead."
+                )
+            item["label"] = y
+        return item
+
+    def __len__(self):
+        return len(self.x)
+
+
+class PLTDataset(PlainDataset):
+    """Dataset for model_1 of AttentionXML.
+
+    Args:
+        x: texts.
+        y: labels.
+        num_classes: number of classes.
+        num_labels_selected: the number of selected labels.
+        labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k).
+        label_scores: scores for each label. Shape: (len(x), predict_top_k).
+    """
+
+    def __init__(
+        self,
+        x,
+        y: Optional[csr_matrix | ndarray] = None,
+        *,
+        num_classes: int,
+        num_labels_selected: int,
+        labels_selected: ndarray | Tensor,
+        label_scores: Optional[ndarray | Tensor] = None,
+    ):
+        super().__init__(x, y)
+        self.num_classes = num_classes
+        self.num_labels_selected = num_labels_selected
+        self.labels_selected = labels_selected
+        self.label_scores = label_scores
+
+    def __getitem__(self, idx: int):
+        item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])}
+
+        if self.y is not None:
+            item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32)
+
+        # PyTorch requires inputs to be of the same shape. Pad any instance with length below num_labels_selected by
+        # randomly selecting labels.
+        # training
+        if self.label_scores is None:
+            # randomly add real labels when the number is below num_labels_selected
+            # some labels might be selected more than once
+            if len(item["labels_selected"]) < self.num_labels_selected:
+                samples = np.random.randint(
+                    self.num_classes,
+                    size=self.num_labels_selected - len(item["labels_selected"]),
+                )
+                item["labels_selected"] = np.concatenate([item["labels_selected"], samples])
+
+        # val/test/pred
+        else:
+            item["label_scores"] = self.label_scores[idx]
+            # add fake labels when the number of labels is below num_labels_selected
+            if len(item["labels_selected"]) < self.num_labels_selected:
+                item["label_scores"] = np.concatenate(
+                    [
+                        item["label_scores"],
+                        [-np.inf] * (self.num_labels_selected - len(item["labels_selected"])),
+                    ]
+                )
+                item["labels_selected"] = np.concatenate(
+                    [
+                        item["labels_selected"],
+                        [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])),
+                    ]
+                )
+        return item
+
+
+###################################### Cluster ######################################
+
+CLUSTER_NAME = "label_clusters"
+CLUSTER_FILE_EXTENSION = FILE_EXTENSION = ".npy"
+
+
+def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path):
+    """Build a binary tree to group labels into clusters, each of which contains up to cluster_size labels. The tree has
+    several layers; nodes in the last layer correspond to the output clusters.
+    Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters look something like:
+    ((0, 2), (1, 3), (4, 5)).
+
+    Args:
+        sparse_x: features extracted from texts in CSR sparse format
+        sparse_y: binarized labels in CSR sparse format
+        cluster_size: the maximum number of labels within each cluster
+        output_dir: directory to store the clustering file
+    """
+    # skip constructing the label tree if the output file already exists
+    output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir)
+    cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}"
+    if cluster_path.exists():
+        logger.info("Clustering has finished in a previous run")
+        return
+
+    # meta info
+    logger.info("Label clustering started")
+    logger.info(f"Cluster size: {cluster_size}")
+    # The height of the tree satisfies the following inequality:
+    # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size
+    height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size)))
+    logger.info(f"Labels will be grouped into {2 ** height} clusters")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # For each label, sum up normalized instances relevant to the label and normalize to get the label representation
+    label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x)))
+
+    # clustering by a binary tree:
+    # at each layer, split each cluster into two. Leaf nodes correspond to the obtained clusters.
+    clusters = [np.arange(sparse_y.shape[1])]
+    for _ in range(height):
+        next_clusters = []
+        for cluster in clusters:
+            next_clusters += _split_cluster(cluster, label_repr[cluster])
+        clusters = next_clusters
+    logger.info(f"Grouped labels into {len(clusters)} clusters")
+
+    np.save(cluster_path, np.asarray(clusters, dtype=object))
+    logger.info(f"Label clustering finished. Saving results to {cluster_path}")
+
+
+def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]:
+    """A variant of KMeans implemented in AttentionXML. Here K = 2. The cluster is partitioned into two groups, each
+    with approximately equal size. Its main differences from the KMeans algorithm in scikit-learn are:
+    1. the distance metric is cosine similarity.
+    2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to centroids.
+
+    Args:
+        cluster: a subset of labels
+        label_repr: the normalized representations of the relationship between labels and texts of the given cluster
+    """
+    # Randomly choose two points as initial centroids and obtain their label representations
+    centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray()
+
+    # Initialize distances (cosine similarity)
+    # Cosine similarity always falls in the interval [-1, 1]
+    old_dist = -2.0
+    new_dist = -1.0
+
+    # "c" denotes clusters
+    c0_idx = None
+    c1_idx = None
+
+    while new_dist - old_dist >= 1e-4:
+        # Notice that label_repr and centroids.T have been normalized
+        # Thus, dist indicates the cosine similarity between points and centroids.
+        dist = label_repr @ centroids.T  # shape: (n, 2)
+
+        # generate clusters
+        # let a = dist[:, 1] - dist[:, 0]; the larger an element of a is, the closer the point is to c1
+        k = len(cluster) // 2
+        c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k)
+        c0_idx = c_idx[:k]
+        c1_idx = c_idx[k:]
+
+        # update distances
+        # the new distance is the average of in-cluster distances to the centroids
+        old_dist = new_dist
+        new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster)
+
+        # update centroids
+        # the new centroid is the normalized average of the points in the cluster
+        centroids = normalize(
+            np.asarray(
+                [
+                    np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))),
+                    np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))),
+                ]
+            )
+        )
+    return cluster[c0_idx], cluster[c1_idx]
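To make the cluster bookkeeping in the new module concrete, here is a self-contained sketch (not part of the package) that mirrors what label2cluster, cluster2label, and PLTModel.scatter_logits do on the 6-label example from the docstrings; all variable names are illustrative:

    import numpy as np
    import torch
    from scipy.sparse import csr_matrix

    # Six labels grouped into three clusters, as in the label2cluster docstring.
    # build_label_tree height for this case: ceil(log2(6 / 2)) = 2 splits.
    clusters = [np.array([0, 1]), np.array([2, 3]), np.array([4, 5])]

    # label -> cluster lookup table, as built inside label2cluster.
    label_to_cluster = np.empty(6, dtype=np.uint32)
    for idx, labels in enumerate(clusters):
        label_to_cluster[labels] = idx

    # Ground-truth labels [0, 1, 4] for one instance, in CSR format.
    y = csr_matrix(np.array([[1, 1, 0, 0, 1, 0]]))
    print(np.unique(label_to_cluster[y.indices]))  # [0 2] -- the docstring example

    # cluster2label direction: expand predicted clusters back to candidate labels.
    candidate_labels = np.concatenate([clusters[c] for c in [0, 2]])
    print(candidate_labels)  # [0 1 4 5]

    # scatter_logits-style trick: a "fake" label with index num_classes absorbs
    # padding scores, then gets sliced off to restore the real label space.
    num_classes = 6
    labels_selected = torch.tensor([[0, 1, 4, 5, num_classes]])  # last slot is padding
    scores = torch.tensor([[0.9, 0.8, 0.7, 0.2, 0.0]])
    full = torch.zeros(1, num_classes + 1)
    full.scatter_(dim=1, index=labels_selected, src=scores)
    print(full[:, :-1])  # tensor([[0.9000, 0.8000, 0.0000, 0.0000, 0.7000, 0.2000]])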
--- a/libmultilabel/nn/model.py
+++ b/libmultilabel/nn/model.py
@@ -155,7 +155,7 @@ class MultiLabelModel(L.LightningModule):
         Returns:
             dict: Top k label indexes and the prediction scores.
         """
-
+        pred_logits = self(batch)
         pred_scores = pred_logits.detach().cpu().numpy()
         k = self.save_k_predictions
         top_k_idx = argsort_top_k(pred_scores, k, axis=1)
@@ -163,6 +163,10 @@ class MultiLabelModel(L.LightningModule):
 
         return {"top_k_pred": top_k_idx, "top_k_pred_scores": top_k_scores}
 
+    def forward(self, batch):
+        """compute predicted logits"""
+        return self.network(batch)["logits"]
+
     def print(self, *args, **kwargs):
         """Prints only from process 0 and not in silent mode. Use this in any
         distributed mode to log only once."""
@@ -178,7 +182,6 @@ class Model(MultiLabelModel):
     Args:
         classes (list): List of class names.
        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
         network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
         loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
             cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
@@ -189,7 +192,6 @@ class Model(MultiLabelModel):
         self,
         classes,
         word_dict,
-        embed_vecs,
         network,
         loss_function="binary_cross_entropy_with_logits",
         log_path=None,
@@ -200,7 +202,6 @@ class Model(MultiLabelModel):
             ignore=["log_path"]
         )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
         self.word_dict = word_dict
-        self.embed_vecs = embed_vecs
         self.classes = classes
         self.network = network
         self.configure_loss_function(loss_function)
@@ -224,8 +225,7 @@ class Model(MultiLabelModel):
             pred_logits (torch.Tensor): The predict logits (batch_size, num_classes).
         """
         target_labels = batch["label"]
-
-        pred_logits = outputs["logits"]
+        pred_logits = self(batch)
         loss = self.loss_function(pred_logits, target_labels.float())
 
         return loss, pred_logits
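These model.py changes route every logit computation through the new forward, so self(batch) — that is, LightningModule.__call__ — is now the single place where the network runs; subclasses such as PLTModel above can then override or reuse it uniformly. A stripped-down sketch of the pattern (TinyNet and TinyModel are illustrative, not part of the package):

    import torch
    from torch import nn
    import lightning as L

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 3)

        def forward(self, batch):
            # mirrors the package convention of returning a dict with "logits"
            return {"logits": self.linear(batch["text"])}

    class TinyModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.network = TinyNet()

        def forward(self, batch):
            # the 0.7.0 refactor: forward computes the predicted logits
            return self.network(batch)["logits"]

    model = TinyModel()
    batch = {"text": torch.randn(2, 4)}
    print(model(batch).shape)  # torch.Size([2, 3]) -- self(batch) dispatches to forward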
--- a/libmultilabel/nn/networks/__init__.py
+++ b/libmultilabel/nn/networks/__init__.py
@@ -9,6 +9,7 @@ from .labelwise_attention_networks import BiGRULWAN
 from .labelwise_attention_networks import BiLSTMLWAN
 from .labelwise_attention_networks import BiLSTMLWMHAN
 from .labelwise_attention_networks import CNNLWAN
+from .labelwise_attention_networks import AttentionXML_0, AttentionXML_1
 
 
 def get_init_weight_func(init_weight):
--- a/libmultilabel/nn/networks/labelwise_attention_networks.py
+++ b/libmultilabel/nn/networks/labelwise_attention_networks.py
@@ -10,6 +10,8 @@ from .modules import (
     LabelwiseAttention,
     LabelwiseMultiHeadAttention,
     LabelwiseLinearOutput,
+    PartialLabelwiseAttention,
+    MultilayerLinearOutput,
 )
 
 
@@ -266,3 +268,70 @@ class CNNLWAN(LabelwiseAttentionNetwork):
         x, _ = self.attention(x)  # (batch_size, num_classes, hidden_dim)
         x = self.output(x)  # (batch_size, num_classes)
         return {"logits": x}
+
+
+class AttentionXML_0(nn.Module):
+    def __init__(
+        self,
+        embed_vecs,
+        num_classes: int,
+        rnn_dim: int,
+        linear_size: list,
+        freeze_embed_training: bool = False,
+        rnn_layers: int = 1,
+        embed_dropout: float = 0.2,
+        encoder_dropout: float = 0,
+        post_encoder_dropout: float = 0.5,
+    ):
+        super().__init__()
+        self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
+        self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
+        self.attention = LabelwiseAttention(rnn_dim, num_classes)
+        self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)
+
+    def forward(self, inputs):
+        x = inputs["text"]
+        # the index of padding is 0
+        masks = x != 0
+        lengths = masks.sum(dim=1)
+        masks = masks[:, : lengths.max()]
+
+        x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
+        x = self.encoder(x, lengths)  # batch_size, length, hidden_size
+        x, _ = self.attention(x)  # batch_size, num_classes, hidden_size
+        x = self.output(x)  # batch_size, num_classes
+        return {"logits": x}
+
+
+class AttentionXML_1(nn.Module):
+    def __init__(
+        self,
+        embed_vecs,
+        num_classes: int,
+        rnn_dim: int,
+        linear_size: list,
+        freeze_embed_training: bool = False,
+        rnn_layers: int = 1,
+        embed_dropout: float = 0.2,
+        encoder_dropout: float = 0,
+        post_encoder_dropout: float = 0.5,
+    ):
+        super().__init__()
+        self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
+        self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
+        self.attention = PartialLabelwiseAttention(rnn_dim, num_classes)
+        self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)
+
+    def forward(self, inputs):
+        x = inputs["text"]
+        labels_selected = inputs["labels_selected"]
+        # the index of padding is 0
+        masks = x != 0
+        lengths = masks.sum(dim=1)
+        masks = masks[:, : lengths.max()]
+
+        x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
+        x = self.encoder(x, lengths)  # batch_size, length, hidden_size
+        x, _ = self.attention(x, labels_selected)  # batch_size, sample_size, hidden_size
+        x = self.output(x)  # batch_size, sample_size
+        return {"logits": x}
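Both new networks trim each batch to its longest non-padded sequence before running the encoder. A quick standalone illustration of that masking arithmetic (the tensors are made up; index 0 is padding, as the forward methods assume):

    import torch

    x = torch.tensor([[5, 3, 9, 0, 0],
                      [7, 0, 0, 0, 0]])
    masks = x != 0
    lengths = masks.sum(dim=1)          # tensor([3, 1])
    x_trimmed = x[:, : lengths.max()]   # drop columns that are padding everywhere
    print(lengths.tolist(), x_trimmed.shape)  # [3, 1] torch.Size([2, 3])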
--- a/libmultilabel/nn/networks/modules.py
+++ b/libmultilabel/nn/networks/modules.py
@@ -210,3 +210,35 @@ class LabelwiseLinearOutput(nn.Module):
 
     def forward(self, input):
         return (self.output.weight * input).sum(dim=-1) + self.output.bias
+
+
+class PartialLabelwiseAttention(nn.Module):
+    """Similar to LabelwiseAttention.
+    What makes the class different from LabelwiseAttention is that only the weights of selected labels will be
+    updated in a single iteration.
+    """
+
+    def __init__(self, hidden_size, num_classes):
+        super().__init__()
+        self.attention = nn.Embedding(num_classes + 1, hidden_size)
+
+    def forward(self, inputs, labels_selected):
+        attn_inputs = inputs.transpose(1, 2)  # batch_size, hidden_dim, length
+        attn_weights = self.attention(labels_selected)  # batch_size, sample_size, hidden_dim
+        attention = attn_weights @ attn_inputs  # batch_size, sample_size, length
+        attention = F.softmax(attention, -1)  # batch_size, sample_size, length
+        logits = attention @ inputs  # batch_size, sample_size, hidden_dim
+        return logits, attention
+
+
+class MultilayerLinearOutput(nn.Module):
+    def __init__(self, linear_size, output_size):
+        super().__init__()
+        self.linears = nn.ModuleList(nn.Linear(in_s, out_s) for in_s, out_s in zip(linear_size[:-1], linear_size[1:]))
+        self.output = nn.Linear(linear_size[-1], output_size)
+
+    def forward(self, inputs):
+        linear_out = inputs
+        for linear in self.linears:
+            linear_out = F.relu(linear(linear_out))
+        return torch.squeeze(self.output(linear_out), -1)
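A self-contained shape check for the two new modules, using random tensors with arbitrary dimensions. PartialLabelwiseAttention looks up one attention vector per selected label via nn.Embedding (with an extra row for the padding label), and MultilayerLinearOutput reduces each attended vector to a scalar. The class is re-declared here only so the demo runs standalone; the real implementation is the one added above:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class PartialLabelwiseAttention(nn.Module):
        def __init__(self, hidden_size, num_classes):
            super().__init__()
            # one attention vector per label, plus one row for the padding label
            self.attention = nn.Embedding(num_classes + 1, hidden_size)

        def forward(self, inputs, labels_selected):
            attn_inputs = inputs.transpose(1, 2)            # batch, hidden, length
            attn_weights = self.attention(labels_selected)  # batch, sample, hidden
            attention = F.softmax(attn_weights @ attn_inputs, -1)
            return attention @ inputs, attention            # batch, sample, hidden

    batch, length, hidden, num_classes, sample = 2, 7, 8, 100, 5
    x = torch.randn(batch, length, hidden)
    labels_selected = torch.randint(num_classes + 1, (batch, sample))
    out, attn = PartialLabelwiseAttention(hidden, num_classes)(x, labels_selected)
    print(out.shape, attn.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 7])

    # A MultilayerLinearOutput-style head (sketched with nn.Sequential) then maps
    # each selected-label vector to a scalar, giving (batch, sample) logits.
    head = nn.Sequential(nn.Linear(hidden, 16), nn.ReLU(), nn.Linear(16, 1))
    print(head(out).squeeze(-1).shape)  # torch.Size([2, 5])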
--- a/libmultilabel.egg-info/PKG-INFO
+++ b/libmultilabel.egg-info/PKG-INFO
@@ -1,7 +1,7 @@
 Metadata-Version: 2.1
 Name: libmultilabel
-Version: 0.6.2
-Summary: A library for multi-label
+Version: 0.7.0
+Summary: A library for multi-class and multi-label classification
 Home-page: https://github.com/ASUS-AICS/LibMultiLabel
 Author: LibMultiLabel Team
 License: MIT License
--- a/libmultilabel.egg-info/SOURCES.txt
+++ b/libmultilabel.egg-info/SOURCES.txt
@@ -18,6 +18,7 @@ libmultilabel/linear/preprocessor.py
 libmultilabel/linear/tree.py
 libmultilabel/linear/utils.py
 libmultilabel/nn/__init__.py
+libmultilabel/nn/attentionxml.py
 libmultilabel/nn/data_utils.py
 libmultilabel/nn/metrics.py
 libmultilabel/nn/model.py
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,10 +1,10 @@
 [metadata]
 name = libmultilabel
-version = 0.6.2
+version = 0.7.0
 author = LibMultiLabel Team
 license = MIT License
 license_file = LICENSE
-description = A library for multi-label
+description = A library for multi-class and multi-label classification
 long_description = See documentation here: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel
 url = https://github.com/ASUS-AICS/LibMultiLabel
 project_urls =
|