libmultilabel 0.6.2__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/PKG-INFO +2 -2
  2. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/README.md +1 -1
  3. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/linear.py +4 -1
  4. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/tree.py +3 -1
  5. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/utils.py +1 -0
  6. libmultilabel-0.7.0/libmultilabel/nn/attentionxml.py +800 -0
  7. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/model.py +6 -6
  8. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/__init__.py +1 -0
  9. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/labelwise_attention_networks.py +69 -0
  10. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/modules.py +32 -0
  11. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/nn_utils.py +0 -1
  12. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/PKG-INFO +2 -2
  13. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/SOURCES.txt +1 -0
  14. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/setup.cfg +2 -2
  15. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/LICENSE +0 -0
  16. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/__init__.py +0 -0
  17. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/common_utils.py +0 -0
  18. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/__init__.py +0 -0
  19. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/data_utils.py +0 -0
  20. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/metrics.py +0 -0
  21. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/linear/preprocessor.py +0 -0
  22. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/logging.py +0 -0
  23. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/__init__.py +0 -0
  24. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/data_utils.py +0 -0
  25. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/metrics.py +0 -0
  26. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/bert.py +0 -0
  27. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/bert_attention.py +0 -0
  28. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/caml.py +0 -0
  29. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/kim_cnn.py +0 -0
  30. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel/nn/networks/xml_cnn.py +0 -0
  31. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/dependency_links.txt +0 -0
  32. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/requires.txt +0 -0
  33. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/libmultilabel.egg-info/top_level.txt +0 -0
  34. {libmultilabel-0.6.2 → libmultilabel-0.7.0}/pyproject.toml +0 -0
@@ -1,7 +1,7 @@
  Metadata-Version: 2.1
  Name: libmultilabel
- Version: 0.6.2
- Summary: A library for multi-label text classification
+ Version: 0.7.0
+ Summary: A library for multi-class and multi-label classification
  Home-page: https://github.com/ASUS-AICS/LibMultiLabel
  Author: LibMultiLabel Team
  License: MIT License
@@ -1,4 +1,4 @@
- # LibMultiLabel — a Library for Multi-class and Multi-label Text Classification
+ # LibMultiLabel — a Library for Multi-class and Multi-label Classification

  LibMultiLabel is a library for binary, multi-class, and multi-label classification. It has the following functionalities

@@ -17,10 +17,13 @@ __all__ = [
      "predict_values",
      "get_topk_labels",
      "get_positive_labels",
+     "FlatModel",
  ]


  class FlatModel:
+     """A model returned from a training function."""
+
      def __init__(
          self,
          name: str,
@@ -619,7 +622,7 @@ def train_binary_and_multiclass(


  def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
-     """Calculates the decision values associated with x.
+     """Calculates the decision values associated with x, equivalent to model.predict_values(x).

      Args:
          model: A model returned from a training function.
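
The reworded docstring documents an equivalence between the module-level function and the method. A hedged usage sketch (it assumes `model` was produced by one of this module's training functions and that `x` is a CSR feature matrix from the same preprocessor):

    import numpy as np
    import libmultilabel.linear as linear

    decision_values = linear.predict_values(model, x)  # module-level call
    same_values = model.predict_values(x)              # equivalent method call
    assert np.array_equal(decision_values, same_values)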
@@ -10,7 +10,7 @@ from tqdm import tqdm

  from . import linear

- __all__ = ["train_tree"]
+ __all__ = ["train_tree", "TreeModel"]


  class Node:
@@ -38,6 +38,8 @@ class Node:


  class TreeModel:
+     """A model returned from train_tree."""
+
      def __init__(
          self,
          root: Node,
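
Exporting TreeModel through __all__ mainly matters for star imports and type checks; a minimal sketch of what this enables (no training involved):

    from libmultilabel.linear.tree import *  # now also binds TreeModel

    def is_tree_model(model) -> bool:
        return isinstance(model, TreeModel)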
@@ -48,6 +48,7 @@ def save_pipeline(checkpoint_dir: str, preprocessor: Preprocessor, model):
              "model": model,
          },
          f,
+         protocol=pickle.HIGHEST_PROTOCOL,
      )


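The added argument only changes how the pipeline file is serialized. A minimal sketch of the effect (toy payload and file name, not the library's real pipeline objects):

    import pickle

    payload = {"preprocessor": "stub", "model": "stub"}  # stand-ins for the pickled pipeline
    with open("pipeline.pickle", "wb") as f:
        # HIGHEST_PROTOCOL yields smaller files and faster (de)serialization,
        # at the cost of requiring a Python version that knows the protocol
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)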
@@ -0,0 +1,800 @@
+ from __future__ import annotations
+
+ import logging
+ from functools import partial
+ from pathlib import Path
+ from typing import Generator, Sequence, Optional
+
+ import numpy as np
+ import torch
+ from lightning import Trainer
+ from numpy import ndarray
+ from scipy.sparse import csr_matrix, csc_matrix, issparse
+ from scipy.special import expit
+ from sklearn.preprocessing import MultiLabelBinarizer, normalize
+ from torch import Tensor, is_tensor
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from lightning.pytorch.callbacks import ModelCheckpoint
+
+ from .data_utils import UNK
+ from ..common_utils import dump_log
+ from ..linear.preprocessor import Preprocessor
+ from ..nn import networks
+ from ..nn.model import Model
+
+ __all__ = ["PLTTrainer"]
+
+ from .nn_utils import init_trainer, init_model
+
+ logger = logging.getLogger(__name__)
+
+
+ class PLTTrainer:
+     CHECKPOINT_NAME = "model_"
+
+     def __init__(
+         self,
+         config,
+         classes: Optional[list] = None,
+         embed_vecs: Optional[Tensor] = None,
+         word_dict: Optional[dict] = None,
+     ):
+         # The number of levels is fixed at 2; in other words, there will be 2 models.
+         self.multiclass = config.multiclass
+         if self.multiclass:
+             raise ValueError(
+                 "The label space of multi-class datasets is usually not large, so PLT training is unnecessary. "
+                 "Please consider other methods. "
+                 "If you have a multi-class dataset with numerous labels, please let us know."
+             )
+
+         # cluster
+         self.cluster_size = config.cluster_size
+         # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level-1 model training
+         self.beam_width = config.beam_width
+
+         # dataset meta info
+         self.embed_vecs = embed_vecs
+         self.word_dict = word_dict
+         self.classes = classes
+         self.max_seq_length = config.max_seq_length
+         self.num_classes = len(classes)
+
+         # multilabel binarizer fitted to the datasets
+         self.binarizer = None
+
+         # cluster meta info
+         self.cluster_size = config.cluster_size
+
+         # network parameters
+         self.network_config = config.network_config
+         self.init_weight = "xavier_uniform"  # AttentionXML-specific setting
+         self.loss_function = config.loss_function
+
+         # optimizer parameters
+         self.optimizer = config.optimizer
+         self.learning_rate = config.learning_rate
+         self.momentum = config.momentum
+         self.weight_decay = config.weight_decay
+         # learning rate scheduler
+         self.lr_scheduler = config.lr_scheduler
+         self.scheduler_config = config.scheduler_config
+
+         # Trainer parameters
+         self.use_cpu = config.cpu
+         self.accelerator = "cpu" if self.use_cpu else "gpu"
+         self.devices = 1
+         self.num_nodes = 1
+         self.epochs = config.epochs
+         self.limit_train_batches = config.limit_train_batches
+         self.limit_val_batches = config.limit_val_batches
+         self.limit_test_batches = config.limit_test_batches
+
+         # callbacks
+         self.silent = config.silent
+         # EarlyStopping
+         self.early_stopping_metric = config.early_stopping_metric
+         self.patience = config.patience
+         # ModelCheckpoint
+         self.val_metric = config.val_metric
+         self.checkpoint_dir = Path(config.checkpoint_dir)
+
+         self.metrics = config.monitor_metrics
+         self.metric_threshold = config.metric_threshold
+         self.monitor_metrics = config.monitor_metrics
+
+         # dataloader parameters
+         # whether to shuffle the training dataset during training
+         self.shuffle = config.shuffle
+         pin_memory = True if self.accelerator == "gpu" else False
+         # training DataLoader
+         self.dataloader = partial(
+             DataLoader,
+             batch_size=config.batch_size,
+             num_workers=config.data_workers,
+             pin_memory=pin_memory,
+         )
+         # evaluation DataLoader
+         self.eval_dataloader = partial(
+             self.dataloader,
+             batch_size=config.eval_batch_size,
+         )
+
+         # predict
+         self.save_k_predictions = config.save_k_predictions
+
+         # save path
+         self.log_path = config.log_path
+         self.predict_out_path = config.predict_out_path
+         self.config = config
+
+     def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, None, None]:
+         """Map labels to their corresponding clusters in CSR sparse format.
+         Notice that this function deals with SPARSE matrices.
+         Assume there are 6 labels clustered as [(0, 1), (2, 3), (4, 5)]. Here (0, 1) is the cluster with index 0 and so on.
+         Given the ground-truth labels, [0, 1, 4], the resulting clusters are [0, 2].
+
+         Args:
+             cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
+             *labels (csr_matrix): labels in CSR sparse format.
+
+         Returns:
+             Generator[csr_matrix]: resulting clusters converted from labels in CSR sparse format
+         """
+         mapping = np.empty(self.num_classes, dtype=np.uint32)
+         for idx, clusters in enumerate(cluster_mapping):
+             mapping[clusters] = idx
+
+         def _label2cluster(label: csr_matrix) -> csr_matrix:
+             row = []
+             col = []
+             data = []
+             for i in range(label.shape[0]):
+                 # n includes all mapped ancestor clusters
+                 n = np.unique(mapping[label.indices[label.indptr[i] : label.indptr[i + 1]]])
+                 row += [i] * len(n)
+                 col += n.tolist()
+                 data += [1] * len(n)
+             return csr_matrix((data, (row, col)), shape=(label.shape[0], len(cluster_mapping)))
+
+         return (_label2cluster(label) for label in labels)
+
+     @staticmethod
+     def cluster2label(cluster_mapping, clusters, cluster_scores=None):
+         """Expand clusters to their corresponding labels and, if available, assign scores to each label.
+         Labels inside the same cluster have the same scores. This function is applied to predictions from model 0.
+         Notice that the behavior of this function differs from label2cluster.
+         Also notice that this function deals with DENSE matrices.
+
+         Args:
+             cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
+             clusters (np.ndarray): predicted clusters from model 0.
+             cluster_scores (Optional: np.ndarray): predicted scores of each cluster from model 0.
+
+         Returns:
+             the labels expanded from the clusters, plus their per-label scores when cluster_scores is given
+         """
+
+         labels_selected = []
+
+         if cluster_scores is not None:
+             # label_scores are the corresponding scores for the selected labels;
+             # shape: (len(x), cluster_size * top_k)
+             label_scores = []
+             for score, cluster in zip(cluster_scores, clusters):
+                 label_scores += [np.repeat(score, [len(labels) for labels in cluster_mapping[cluster]])]
+                 labels_selected += [np.concatenate(cluster_mapping[cluster])]
+             return labels_selected, label_scores
+         else:
+             labels_selected = [np.concatenate(cluster_mapping[cluster]) for cluster in clusters]
+             return labels_selected
+
+     def fit(self, datasets):
+         """Fit the model to the training dataset.
+
+         Args:
+             datasets: dict containing training, validation, and/or test datasets
+         """
+         if self.get_best_model_path(level=1).exists():
+             return
+
+         # AttentionXML-specific data preprocessing
+         train_val_dataset = datasets["train"] + datasets["val"]
+         train_val_dataset = {
+             "x": [" ".join(i["text"]) for i in train_val_dataset],
+             "y": [i["label"] for i in train_val_dataset],
+         }
+
+         # Preprocessor does tf-idf vectorization and multilabel binarization
+         # For details, see libmultilabel.linear.preprocessor.Preprocessor
+         preprocessor = Preprocessor()
+         datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes}
+         # Preprocessor requires the input dictionary to have a key named "train" and will return a new dictionary with
+         # the same key.
+         train_val_dataset_tf = preprocessor.fit_transform(datasets_temp)["train"]
+         # save binarizer for testing
+         self.binarizer = preprocessor.binarizer
+
+         train_x = self.reformat_text(datasets["train"])
+         val_x = self.reformat_text(datasets["val"])
+
+         train_y = train_val_dataset_tf["y"][: len(datasets["train"])]
+         val_y = train_val_dataset_tf["y"][len(datasets["train"]) :]
+
+         # clusters are saved to disk so that users don't need to provide the original training data when they only
+         # want to run prediction
+         build_label_tree(
+             sparse_x=train_val_dataset_tf["x"],
+             sparse_y=train_val_dataset_tf["y"],
+             cluster_size=self.cluster_size,
+             output_dir=self.checkpoint_dir,
+         )
+
+         clusters = np.load(self.get_cluster_path(), allow_pickle=True)
+
+         # map each y to the parent cluster indices
+         train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y)
+
+         trainer = init_trainer(
+             self.checkpoint_dir,
+             epochs=self.epochs,
+             patience=self.patience,
+             early_stopping_metric=self.early_stopping_metric,
+             val_metric=self.val_metric,
+             silent=self.silent,
+             use_cpu=self.use_cpu,
+             limit_train_batches=self.limit_train_batches,
+             limit_val_batches=self.limit_val_batches,
+             limit_test_batches=self.limit_test_batches,
+             save_checkpoints=True,
+         )
+         trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}0"
+
+         train_dataloader = self.dataloader(PlainDataset(train_x, train_y_clustered), shuffle=self.shuffle)
+         val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered))
+
+         best_model_path = self.get_best_model_path(level=0)
+         if not best_model_path.exists():
+             model_0 = init_model(
+                 model_name="AttentionXML_0",
+                 network_config=self.network_config,
+                 classes=clusters,
+                 word_dict=self.word_dict,
+                 embed_vecs=self.embed_vecs,
+                 init_weight=self.init_weight,
+                 log_path=self.log_path,
+                 learning_rate=self.learning_rate,
+                 optimizer=self.optimizer,
+                 momentum=self.momentum,
+                 weight_decay=self.weight_decay,
+                 lr_scheduler=self.lr_scheduler,
+                 scheduler_config=self.scheduler_config,
+                 val_metric=self.val_metric,
+                 metric_threshold=self.metric_threshold,
+                 monitor_metrics=self.monitor_metrics,
+                 multiclass=self.multiclass,
+                 loss_function=self.loss_function,
+                 silent=self.silent,
+                 save_k_predictions=self.beam_width,
+             )
+
+             logger.info(f"Training level 0. Number of clusters: {len(clusters)}")
+             trainer.fit(model_0, train_dataloader, val_dataloader)
+             logger.info("Finished training level 0")
+
+         logger.info(f"Best model loaded from {best_model_path}")
+         model_0 = Model.load_from_checkpoint(best_model_path)
+
+         logger.info(
+             f"Predicting clusters with the level-0 model. We then select {self.beam_width} clusters for each instance and "
+             f"extract labels from these clusters for level-1 training."
+         )
+         # load training and validation data and predict the corresponding level-0 clusters
+         train_dataloader = self.dataloader(PlainDataset(train_x))
+         val_dataloader = self.dataloader(PlainDataset(val_x))
+
+         train_pred = trainer.predict(model_0, train_dataloader)
+         val_pred = trainer.predict(model_0, val_dataloader)
+
+         train_clusters_pred = np.vstack([i["top_k_pred"] for i in train_pred])
+         val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred]))
+         val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred])
+
+         train_clusters_selected = np.empty((len(train_x), self.beam_width), dtype=np.uint)
+         for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")):
+             # relevant clusters are positive
+             pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]])
+             # Select the relevant clusters first. Then, from the top-predicted clusters, sequentially include them
+             # until the number of clusters reaches the beam width.
+             if len(pos) <= self.beam_width:
+                 selected = pos
+                 for y in ys:
+                     y = y.item()
+                     if len(selected) == self.beam_width:
+                         break
+                     selected.add(y)
+             # If there are more positive clusters than the beam width, keep a positive (true) cluster only if it
+             # appears among the predicted clusters; if fewer than beam_width clusters are selected this way, fill
+             # up with the remaining positive clusters.
+             else:
+                 selected = set()
+                 for y in ys:
+                     y = y.item()
+                     if y in pos:
+                         selected.add(y)
+                     if len(selected) == self.beam_width:
+                         break
+                 if len(selected) < self.beam_width:
+                     selected = (list(selected) + list(pos - selected))[: self.beam_width]
+             train_clusters_selected[i] = np.asarray(list(selected))
+
+         train_labels_selected = PLTTrainer.cluster2label(clusters, train_clusters_selected)
+         val_labels_pred, val_scores_pred = PLTTrainer.cluster2label(clusters, val_clusters_pred, val_scores_pred)
+         num_labels_selected = self.beam_width * max(len(c) for c in clusters)
+
+         trainer = init_trainer(
+             self.checkpoint_dir,
+             epochs=self.epochs,
+             patience=self.patience,
+             early_stopping_metric=self.val_metric,
+             val_metric=self.val_metric,
+             silent=self.silent,
+             use_cpu=self.use_cpu,
+             limit_train_batches=self.limit_train_batches,
+             limit_val_batches=self.limit_val_batches,
+             limit_test_batches=self.limit_test_batches,
+             save_checkpoints=True,
+         )
+         trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}1"
+
+         # train & val dataloaders for training
+         train_dataloader = self.dataloader(
+             PLTDataset(
+                 train_x,
+                 train_y,
+                 num_classes=self.num_classes,
+                 num_labels_selected=num_labels_selected,
+                 labels_selected=train_labels_selected,
+             ),
+             shuffle=self.shuffle,
+         )
+         val_dataloader = self.dataloader(
+             PLTDataset(
+                 val_x,
+                 val_y,
+                 num_classes=self.num_classes,
+                 num_labels_selected=num_labels_selected,
+                 labels_selected=val_labels_pred,
+                 label_scores=val_scores_pred,
+             ),
+         )
+
+         try:
+             network = getattr(networks, "AttentionXML_1")(
+                 embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config)
+             )
+         except Exception as e:
+             raise AttributeError("Failed to initialize AttentionXML") from e
+
+         model_1 = PLTModel(
+             classes=self.classes,
+             word_dict=self.word_dict,
+             network=network,
+             log_path=self.log_path,
+             learning_rate=self.learning_rate,
+             optimizer=self.optimizer,
+             momentum=self.momentum,
+             weight_decay=self.weight_decay,
+             lr_scheduler=self.lr_scheduler,
+             scheduler_config=self.scheduler_config,
+             val_metric=self.val_metric,
+             metric_threshold=self.metric_threshold,
+             monitor_metrics=self.monitor_metrics,
+             multiclass=self.multiclass,
+             loss_function=self.loss_function,
+             silent=self.silent,
+             save_k_predictions=self.save_k_predictions,
+         )
+         logger.info("Initializing the level-1 model with weights from level 0")
+         # For weights not initialized by the level-0 model, use xavier uniform initialization
+         torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight)
+         # As the attention layer of model 1 differs from model 0, each remaining layer is initialized separately
+         model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict())
+         model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict())
+         model_1.network.output.load_state_dict(model_0.network.output.state_dict())
+
+         del model_0
+
+         logger.info(
+             f"Training level 1. Number of labels: {self.num_classes}. "
+             f"Number of labels selected: {train_dataloader.dataset.num_labels_selected}"
+         )
+         trainer.fit(model_1, train_dataloader, val_dataloader)
+         logger.info(f"Best model saved to {self.get_best_model_path(level=1)}")
+         logger.info("Finished training level 1")
+
+     def test(self, dataset):
+         # retrieve word_dict from model_1
+         # prediction starts from level 0
+         model_0 = Model.load_from_checkpoint(
+             self.get_best_model_path(level=0),
+             save_k_predictions=self.beam_width,
+         )
+         model_1 = PLTModel.load_from_checkpoint(
+             self.get_best_model_path(level=1),
+             save_k_predictions=self.save_k_predictions,
+             metrics=self.metrics,
+         )
+         self.word_dict = model_1.word_dict
+         classes = model_1.classes
+
+         test_x = self.reformat_text(dataset)
+
+         if self.binarizer is None:
+             binarizer = MultiLabelBinarizer(classes=classes, sparse_output=True)
+             binarizer.fit(None)
+             test_y = binarizer.transform((i["label"] for i in dataset))
+         else:
+             test_y = self.binarizer.transform((i["label"] for i in dataset))
+         logger.info("Testing process started")
+         trainer = Trainer(
+             devices=1,
+             accelerator=self.accelerator,
+             logger=False,
+         )
+
+         test_dataloader = self.eval_dataloader(PlainDataset(test_x))
+
+         logger.info(f"Predicting level 0. Number of clusters selected per instance: {self.beam_width}")
+         test_pred = trainer.predict(model_0, test_dataloader)
+         test_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
+         test_clusters_pred = np.vstack([i["top_k_pred"] for i in test_pred])
+
+         clusters = np.load(self.get_cluster_path(), allow_pickle=True)
+         test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(clusters, test_clusters_pred, test_scores_pred)
+         num_labels_selected = self.beam_width * max(len(c) for c in clusters)
+
+         test_dataloader = self.eval_dataloader(
+             PLTDataset(
+                 test_x,
+                 test_y,
+                 num_classes=self.num_classes,
+                 num_labels_selected=num_labels_selected,
+                 labels_selected=test_labels_pred,
+                 label_scores=test_scores_pred,
+             ),
+         )
+
+         logger.info("Testing level 1")
+         trainer.test(model_1, test_dataloader)
+         logger.info("Testing process finished")
+
+         if self.save_k_predictions > 0:
+             batch_predictions = trainer.predict(model_1, test_dataloader)
+             pred_labels = np.vstack([batch["top_k_pred"] for batch in batch_predictions])
+             pred_scores = np.vstack([batch["top_k_pred_scores"] for batch in batch_predictions])
+             with open(self.predict_out_path, "w") as fp:
+                 for pred_label, pred_score in zip(pred_labels, pred_scores):
+                     out_str = " ".join(
+                         [f"{model_1.classes[label]}:{score:.4}" for label, score in zip(pred_label, pred_score)]
+                     )
+                     fp.write(out_str + "\n")
+             logger.info(f"Saved predictions to: {self.predict_out_path}")
+
+         dump_log(self.log_path, config=self.config)
+
+     def reformat_text(self, dataset):
+         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
+         encoded_text = list(
+             map(
+                 lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)
+                 if text
+                 else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                 [instance["text"][: self.max_seq_length] for instance in dataset],
+             )
+         )
+         # pad the first entry to max_seq_length if necessary so that pad_sequence pads every batch to the same width
+         encoded_text[0] = torch.cat(
+             (
+                 encoded_text[0],
+                 torch.tensor(0, dtype=torch.int64).repeat(self.max_seq_length - encoded_text[0].shape[0]),
+             )
+         )
+         encoded_text = pad_sequence(encoded_text, batch_first=True)
+         return encoded_text
+
+     def get_best_model_path(self, level: int) -> Path:
+         return self.checkpoint_dir / f"{self.CHECKPOINT_NAME}{level}{ModelCheckpoint.FILE_EXTENSION}"
+
+     def get_cluster_path(self) -> Path:
+         return self.checkpoint_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}"
+
+
+ ###################################### Model ######################################
+
+
+ class PLTModel(Model):
+     def __init__(
+         self,
+         classes,
+         word_dict,
+         network,
+         loss_function="binary_cross_entropy_with_logits",
+         log_path=None,
+         **kwargs,
+     ):
+         super().__init__(
+             classes=classes,
+             word_dict=word_dict,
+             network=network,
+             loss_function=loss_function,
+             log_path=log_path,
+             **kwargs,
+         )
+
+     def scatter_logits(
+         self,
+         logits: Tensor,
+         labels_selected: Tensor,
+         label_scores: Tensor,
+     ) -> Tensor:
+         """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
+         the whole label space. The scores of unsampled labels are set to 0."""
+         src = torch.sigmoid(logits.detach()) * label_scores
+         # During validation/testing, many fake labels might exist in a batch for the purpose of padding.
+         # A fake label has index len(classes) and does not belong to the real label space.
+         preds = torch.zeros(
+             labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
+         )
+         preds.scatter_(dim=1, index=labels_selected, src=src)
+         # slicing removes fake labels whose index is exactly len(self.classes)
+         # afterwards, preds is restored to the real label space
+         preds = preds[:, :-1]
+         return preds
+
+     def shared_step(self, batch):
+         """Return the loss and predicted logits of the network.
+
+         Args:
+             batch (dict): A batch of text and labels.
+
+         Returns:
+             loss (torch.Tensor): Loss between the target and the predicted logits.
+             pred_logits (torch.Tensor): The predicted logits (batch_size, num_classes).
+         """
+         y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1)
+         pred_logits = self(batch)
+         loss = self.loss_function(pred_logits, y.float())
+         return loss, pred_logits
+
+     def _shared_eval_step(self, batch, batch_idx):
+         logits = self(batch)
+         logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"])
+         self.eval_metric.update(logits, batch["label"].long())
+
+     def predict_step(self, batch, batch_idx, dataloader_idx=0):
+         logits = self(batch)
+         scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.save_k_predictions)
+         # Convert the probabilities back to logits to align with the rest of LibMultiLabel, where logits rather
+         # than probabilities are returned
+         logits = torch.logit(scores)
+         return {
+             "top_k_pred": torch.take_along_dim(batch["labels_selected"], labels, dim=1).numpy(force=True),
+             "top_k_pred_scores": logits.numpy(force=True),
+         }
+
+
+ ###################################### Dataset ######################################
+
+
+ class PlainDataset(Dataset):
+     """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label datasets.
+     This class is necessary because it can process labels in sparse format, which TextDataset cannot.
+     Moreover, TextDataset always performs multilabel binarization, but AttentionXML already does this while
+     generating clusters, so there is no need to binarize again.
+
+     Args:
+         x (list | ndarray | Tensor): texts.
+         y (Optional: csr_matrix | ndarray | Tensor): labels.
+     """
+
+     def __init__(self, x, y=None):
+         if y is not None:
+             assert len(x) == y.shape[0], "Sizes mismatch between texts and labels"
+         self.x = x
+         self.y = y
+
+     def __getitem__(self, idx: int) -> dict:
+         item = {"text": self.x[idx]}
+
+         # train/val/test
+         if self.y is not None:
+             if issparse(self.y):
+                 y = self.y[idx].toarray().squeeze(0).astype(np.int32)
+             elif isinstance(self.y, ndarray):
+                 y = self.y[idx].astype(np.int32)
+             elif is_tensor(self.y):
+                 y = self.y[idx].int()
+             else:
+                 raise TypeError(
+                     "The type of y should be one of scipy.sparse.csr_matrix, numpy.ndarray, and torch.Tensor. "
+                     f"But got {type(self.y)} instead."
+                 )
+             item["label"] = y
+         return item
+
+     def __len__(self):
+         return len(self.x)
+
+
+ class PLTDataset(PlainDataset):
+     """Dataset for model_1 of AttentionXML.
+
+     Args:
+         x: texts.
+         y: labels.
+         num_classes: number of classes.
+         num_labels_selected: the number of selected labels per instance.
+         labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k).
+         label_scores: scores for each label. Shape: (len(x), predict_top_k).
+     """
+
+     def __init__(
+         self,
+         x,
+         y: Optional[csr_matrix | ndarray] = None,
+         *,
+         num_classes: int,
+         num_labels_selected: int,
+         labels_selected: ndarray | Tensor,
+         label_scores: Optional[ndarray | Tensor] = None,
+     ):
+         super().__init__(x, y)
+         self.num_classes = num_classes
+         self.num_labels_selected = num_labels_selected
+         self.labels_selected = labels_selected
+         self.label_scores = label_scores
+
+     def __getitem__(self, idx: int):
+         item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])}
+
+         if self.y is not None:
+             item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32)
+
+         # PyTorch requires inputs to be of the same shape. Pad any instance with fewer than num_labels_selected
+         # labels by randomly selecting labels.
+         # training
+         if self.label_scores is None:
+             # randomly add real labels when the number is below num_labels_selected
+             # some labels might be selected more than once
+             if len(item["labels_selected"]) < self.num_labels_selected:
+                 samples = np.random.randint(
+                     self.num_classes,
+                     size=self.num_labels_selected - len(item["labels_selected"]),
+                 )
+                 item["labels_selected"] = np.concatenate([item["labels_selected"], samples])
+
+         # val/test/pred
+         else:
+             item["label_scores"] = self.label_scores[idx]
+             # add fake labels when the number of labels is below num_labels_selected
+             if len(item["labels_selected"]) < self.num_labels_selected:
+                 item["label_scores"] = np.concatenate(
+                     [
+                         item["label_scores"],
+                         [-np.inf] * (self.num_labels_selected - len(item["labels_selected"])),
+                     ]
+                 )
+                 item["labels_selected"] = np.concatenate(
+                     [
+                         item["labels_selected"],
+                         [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])),
+                     ]
+                 )
+         return item
+
+
+ ###################################### Cluster ######################################
+
+ CLUSTER_NAME = "label_clusters"
+ CLUSTER_FILE_EXTENSION = FILE_EXTENSION = ".npy"
+
+
+ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path):
+     """Build a binary tree to group labels into clusters, each of which contains up to cluster_size labels. The tree
+     has several layers; nodes in the last layer correspond to the output clusters.
+     Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters look something like:
+     ((0, 2), (1, 3), (4, 5)).
+
+     Args:
+         sparse_x: features extracted from texts in CSR sparse format
+         sparse_y: binarized labels in CSR sparse format
+         cluster_size: the maximum number of labels within each cluster
+         output_dir: directory to store the clustering file
+     """
+     # skip constructing the label tree if the output file already exists
+     output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir)
+     cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}"
+     if cluster_path.exists():
+         logger.info("Clustering has finished in a previous run")
+         return
+
+     # meta info
+     logger.info("Label clustering started")
+     logger.info(f"Cluster size: {cluster_size}")
+     # The height of the tree satisfies the following inequality:
+     # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size
+     height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size)))
+     logger.info(f"Labels will be grouped into {2 ** height} clusters")
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # For each label, sum up the normalized instances relevant to the label and normalize again to get the label
+     # representation
+     label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x)))
+
+     # clustering by a binary tree:
+     # at each layer, split each cluster into two. Leaf nodes correspond to the obtained clusters.
+     clusters = [np.arange(sparse_y.shape[1])]
+     for _ in range(height):
+         next_clusters = []
+         for cluster in clusters:
+             next_clusters += _split_cluster(cluster, label_repr[cluster])
+         clusters = next_clusters
+         logger.info(f"Grouped into {len(clusters)} clusters")
+
+     np.save(cluster_path, np.asarray(clusters, dtype=object))
+     logger.info(f"Label clustering finished. Results saved to {cluster_path}")
+
+
+ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]:
+     """A variant of KMeans implemented in AttentionXML. Here K = 2. The cluster is partitioned into two groups of
+     approximately equal size. Its main differences from the KMeans algorithm in scikit-learn are:
+     1. the distance metric is cosine similarity.
+     2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to the centroids.
+
+     Args:
+         cluster: a subset of labels
+         label_repr: the normalized representations of the relationship between labels and texts of the given cluster
+     """
+     # Randomly choose two points as initial centroids and obtain their label representations
+     centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray()
+
+     # Initialize distances (cosine similarity)
+     # Cosine similarity always falls in the interval [-1, 1]
+     old_dist = -2.0
+     new_dist = -1.0
+
+     # "c" denotes clusters
+     c0_idx = None
+     c1_idx = None
+
+     while new_dist - old_dist >= 1e-4:
+         # Notice that label_repr and centroids.T have been normalized
+         # Thus, dist indicates the cosine similarity between points and centroids.
+         dist = label_repr @ centroids.T  # shape: (n, 2)
+
+         # generate clusters
+         # let a = dist[:, 1] - dist[:, 0]; the larger an element of a is, the closer the point is to c1
+         k = len(cluster) // 2
+         c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k)
+         c0_idx = c_idx[:k]
+         c1_idx = c_idx[k:]
+
+         # update distances
+         # the new distance is the average of the in-cluster distances to the centroids
+         old_dist = new_dist
+         new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster)
+
+         # update centroids
+         # the new centroid is the normalized average of the points in the cluster
+         centroids = normalize(
+             np.asarray(
+                 [
+                     np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))),
+                     np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))),
+                 ]
+             )
+         )
+     return cluster[c0_idx], cluster[c1_idx]
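
The label-to-cluster and cluster-to-label mappings above are the core of the two-level PLT. Below is a standalone sketch of the round trip, re-implemented for illustration using the 6-label example from the label2cluster docstring (not the library API; the scores are made up):

    import numpy as np
    from scipy.sparse import csr_matrix

    # three clusters of two labels each: (0, 1), (2, 3), (4, 5)
    cluster_mapping = np.empty(3, dtype=object)
    cluster_mapping[0] = np.array([0, 1])
    cluster_mapping[1] = np.array([2, 3])
    cluster_mapping[2] = np.array([4, 5])

    # label -> cluster lookup, mirroring the `mapping` array built in label2cluster
    label_to_cluster = np.empty(6, dtype=np.uint32)
    for idx, labels in enumerate(cluster_mapping):
        label_to_cluster[labels] = idx

    # one instance whose ground-truth labels are {0, 1, 4}, in CSR format
    y = csr_matrix(([1, 1, 1], ([0, 0, 0], [0, 1, 4])), shape=(1, 6))
    clusters = np.unique(label_to_cluster[y.indices])
    print(clusters)  # [0 2], as in the docstring

    # cluster2label direction: expand predicted clusters back to labels,
    # repeating each cluster's score for every label it contains
    cluster_scores = np.array([0.9, 0.4])
    labels_selected = np.concatenate(cluster_mapping[clusters])
    label_scores = np.repeat(cluster_scores, [len(c) for c in cluster_mapping[clusters]])
    print(labels_selected, label_scores)  # [0 1 4 5] [0.9 0.9 0.4 0.4]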
@@ -155,7 +155,7 @@ class MultiLabelModel(L.LightningModule):
          Returns:
              dict: Top k label indexes and the prediction scores.
          """
-         _, pred_logits = self.shared_step(batch)
+         pred_logits = self(batch)
          pred_scores = pred_logits.detach().cpu().numpy()
          k = self.save_k_predictions
          top_k_idx = argsort_top_k(pred_scores, k, axis=1)
@@ -163,6 +163,10 @@ class MultiLabelModel(L.LightningModule):

          return {"top_k_pred": top_k_idx, "top_k_pred_scores": top_k_scores}

+     def forward(self, batch):
+         """compute predicted logits"""
+         return self.network(batch)["logits"]
+
      def print(self, *args, **kwargs):
          """Prints only from process 0 and not in silent mode. Use this in any
          distributed mode to log only once."""
@@ -178,7 +182,6 @@ class Model(MultiLabelModel):
      Args:
          classes (list): List of class names.
          word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-         embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
          network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
          loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
              cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
@@ -189,7 +192,6 @@ class Model(MultiLabelModel):
          self,
          classes,
          word_dict,
-         embed_vecs,
          network,
          loss_function="binary_cross_entropy_with_logits",
          log_path=None,
@@ -200,7 +202,6 @@ class Model(MultiLabelModel):
              ignore=["log_path"]
          )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has a unique log_path (result_dir).
          self.word_dict = word_dict
-         self.embed_vecs = embed_vecs
          self.classes = classes
          self.network = network
          self.configure_loss_function(loss_function)
@@ -224,8 +225,7 @@ class Model(MultiLabelModel):
              pred_logits (torch.Tensor): The predicted logits (batch_size, num_classes).
          """
          target_labels = batch["label"]
-         outputs = self.network(batch)
-         pred_logits = outputs["logits"]
+         pred_logits = self(batch)
          loss = self.loss_function(pred_logits, target_labels.float())

          return loss, pred_logits
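
The new forward method makes the Model wrapper depend only on the wrapped network returning a dict with a "logits" entry. A toy stand-in network illustrating that contract (names here are illustrative, not part of the library):

    import torch
    from torch import nn

    class ToyNetwork(nn.Module):
        # minimal network satisfying the {"logits": ...} contract assumed by forward()
        def __init__(self, vocab_size=100, num_classes=4):
            super().__init__()
            self.embed = nn.EmbeddingBag(vocab_size, 16)
            self.out = nn.Linear(16, num_classes)

        def forward(self, batch):
            return {"logits": self.out(self.embed(batch["text"]))}

    net = ToyNetwork()
    batch = {"text": torch.randint(0, 100, (2, 10))}
    logits = net(batch)["logits"]  # what Model.forward(batch) would now return
    print(logits.shape)            # torch.Size([2, 4])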
@@ -9,6 +9,7 @@ from .labelwise_attention_networks import BiGRULWAN
  from .labelwise_attention_networks import BiLSTMLWAN
  from .labelwise_attention_networks import BiLSTMLWMHAN
  from .labelwise_attention_networks import CNNLWAN
+ from .labelwise_attention_networks import AttentionXML_0, AttentionXML_1


  def get_init_weight_func(init_weight):
@@ -10,6 +10,8 @@ from .modules import (
      LabelwiseAttention,
      LabelwiseMultiHeadAttention,
      LabelwiseLinearOutput,
+     PartialLabelwiseAttention,
+     MultilayerLinearOutput,
  )


@@ -266,3 +268,70 @@ class CNNLWAN(LabelwiseAttentionNetwork):
          x, _ = self.attention(x)  # (batch_size, num_classes, hidden_dim)
          x = self.output(x)  # (batch_size, num_classes)
          return {"logits": x}
+
+
+ class AttentionXML_0(nn.Module):
+     def __init__(
+         self,
+         embed_vecs,
+         num_classes: int,
+         rnn_dim: int,
+         linear_size: list,
+         freeze_embed_training: bool = False,
+         rnn_layers: int = 1,
+         embed_dropout: float = 0.2,
+         encoder_dropout: float = 0,
+         post_encoder_dropout: float = 0.5,
+     ):
+         super().__init__()
+         self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
+         self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
+         self.attention = LabelwiseAttention(rnn_dim, num_classes)
+         self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)
+
+     def forward(self, inputs):
+         x = inputs["text"]
+         # the index of padding is 0
+         masks = x != 0
+         lengths = masks.sum(dim=1)
+         masks = masks[:, : lengths.max()]
+
+         x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
+         x = self.encoder(x, lengths)  # batch_size, length, hidden_size
+         x, _ = self.attention(x)  # batch_size, num_classes, hidden_size
+         x = self.output(x)  # batch_size, num_classes
+         return {"logits": x}
+
+
+ class AttentionXML_1(nn.Module):
+     def __init__(
+         self,
+         embed_vecs,
+         num_classes: int,
+         rnn_dim: int,
+         linear_size: list,
+         freeze_embed_training: bool = False,
+         rnn_layers: int = 1,
+         embed_dropout: float = 0.2,
+         encoder_dropout: float = 0,
+         post_encoder_dropout: float = 0.5,
+     ):
+         super().__init__()
+         self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
+         self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
+         self.attention = PartialLabelwiseAttention(rnn_dim, num_classes)
+         self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)
+
+     def forward(self, inputs):
+         x = inputs["text"]
+         labels_selected = inputs["labels_selected"]
+         # the index of padding is 0
+         masks = x != 0
+         lengths = masks.sum(dim=1)
+         masks = masks[:, : lengths.max()]
+
+         x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
+         x = self.encoder(x, lengths)  # batch_size, length, hidden_size
+         x, _ = self.attention(x, labels_selected)  # batch_size, sample_size, hidden_size
+         x = self.output(x)  # batch_size, sample_size
+         return {"logits": x}
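
A quick shape sketch for the level-0 network (it assumes the package's modules are importable; the embeddings are random and untrained):

    import torch
    from libmultilabel.nn.networks import AttentionXML_0

    embed_vecs = torch.randn(1000, 50)                 # toy pretrained embedding table
    net = AttentionXML_0(embed_vecs, num_classes=8, rnn_dim=32, linear_size=[32])
    batch = {"text": torch.randint(1, 1000, (2, 20))}  # index 0 is reserved for padding
    print(net(batch)["logits"].shape)                  # expected: torch.Size([2, 8])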
@@ -210,3 +210,35 @@ class LabelwiseLinearOutput(nn.Module):

      def forward(self, input):
          return (self.output.weight * input).sum(dim=-1) + self.output.bias
+
+
+ class PartialLabelwiseAttention(nn.Module):
+     """Similar to LabelwiseAttention, except that only the weights of the selected labels are updated in a single
+     iteration.
+     """
+
+     def __init__(self, hidden_size, num_classes):
+         super().__init__()
+         self.attention = nn.Embedding(num_classes + 1, hidden_size)
+
+     def forward(self, inputs, labels_selected):
+         attn_inputs = inputs.transpose(1, 2)  # batch_size, hidden_dim, length
+         attn_weights = self.attention(labels_selected)  # batch_size, sample_size, hidden_dim
+         attention = attn_weights @ attn_inputs  # batch_size, sample_size, length
+         attention = F.softmax(attention, -1)  # batch_size, sample_size, length
+         logits = attention @ inputs  # batch_size, sample_size, hidden_dim
+         return logits, attention
+
+
+ class MultilayerLinearOutput(nn.Module):
+     def __init__(self, linear_size, output_size):
+         super().__init__()
+         self.linears = nn.ModuleList(nn.Linear(in_s, out_s) for in_s, out_s in zip(linear_size[:-1], linear_size[1:]))
+         self.output = nn.Linear(linear_size[-1], output_size)
+
+     def forward(self, inputs):
+         linear_out = inputs
+         for linear in self.linears:
+             linear_out = F.relu(linear(linear_out))
+         return torch.squeeze(self.output(linear_out), -1)
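
A shape check for the two new modules (it assumes the definitions above; inputs are random):

    import torch

    attn = PartialLabelwiseAttention(hidden_size=8, num_classes=100)
    x = torch.randn(2, 15, 8)                        # (batch_size, length, hidden_dim)
    labels_selected = torch.randint(0, 101, (2, 5))  # index 100 acts as the padding label
    out, weights = attn(x, labels_selected)          # out: (2, 5, 8); weights: (2, 5, 15)

    head = MultilayerLinearOutput([8, 4], 1)
    print(head(out).shape)                           # torch.Size([2, 5])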
@@ -100,7 +100,6 @@ def init_model(
      model = Model(
          classes=classes,
          word_dict=word_dict,
-         embed_vecs=embed_vecs,
          network=network,
          log_path=log_path,
          learning_rate=learning_rate,
@@ -1,7 +1,7 @@
  Metadata-Version: 2.1
  Name: libmultilabel
- Version: 0.6.2
- Summary: A library for multi-label text classification
+ Version: 0.7.0
+ Summary: A library for multi-class and multi-label classification
  Home-page: https://github.com/ASUS-AICS/LibMultiLabel
  Author: LibMultiLabel Team
  License: MIT License
@@ -18,6 +18,7 @@ libmultilabel/linear/preprocessor.py
  libmultilabel/linear/tree.py
  libmultilabel/linear/utils.py
  libmultilabel/nn/__init__.py
+ libmultilabel/nn/attentionxml.py
  libmultilabel/nn/data_utils.py
  libmultilabel/nn/metrics.py
  libmultilabel/nn/model.py
@@ -1,10 +1,10 @@
  [metadata]
  name = libmultilabel
- version = 0.6.2
+ version = 0.7.0
  author = LibMultiLabel Team
  license = MIT License
  license_file = LICENSE
- description = A library for multi-label text classification
+ description = A library for multi-class and multi-label classification
  long_description = See documentation here: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel
  url = https://github.com/ASUS-AICS/LibMultiLabel
  project_urls =