libmultilabel 0.5.0__tar.gz → 0.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/PKG-INFO +1 -1
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/common_utils.py +15 -30
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/metrics.py +7 -3
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/data_utils.py +7 -12
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/metrics.py +7 -9
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel.egg-info/PKG-INFO +1 -1
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/setup.cfg +1 -1
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/LICENSE +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/README.md +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/__init__.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/__init__.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/data_utils.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/linear.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/preprocessor.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/tree.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/linear/utils.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/logging.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/__init__.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/model.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/__init__.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/bert.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/bert_attention.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/caml.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/kim_cnn.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/labelwise_attention_networks.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/modules.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/networks/xml_cnn.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel/nn/nn_utils.py +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel.egg-info/SOURCES.txt +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel.egg-info/dependency_links.txt +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel.egg-info/requires.txt +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/libmultilabel.egg-info/top_level.txt +0 -0
- {libmultilabel-0.5.0 → libmultilabel-0.5.2}/pyproject.toml +0 -0
|
@@ -3,6 +3,7 @@ import json
|
|
|
3
3
|
import logging
|
|
4
4
|
import os
|
|
5
5
|
import time
|
|
6
|
+
from functools import wraps
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
|
|
@@ -41,36 +42,6 @@ class AttributeDict(dict):
|
|
|
41
42
|
return {k: self[k] for k in self._used}
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
class Timer(object):
|
|
45
|
-
"""Computes elasped time."""
|
|
46
|
-
|
|
47
|
-
def __init__(self):
|
|
48
|
-
self.reset()
|
|
49
|
-
|
|
50
|
-
def reset(self):
|
|
51
|
-
self.running = True
|
|
52
|
-
self.total = 0
|
|
53
|
-
self.start = time.time()
|
|
54
|
-
return self
|
|
55
|
-
|
|
56
|
-
def resume(self):
|
|
57
|
-
if not self.running:
|
|
58
|
-
self.running = True
|
|
59
|
-
self.start = time.time()
|
|
60
|
-
return self
|
|
61
|
-
|
|
62
|
-
def stop(self):
|
|
63
|
-
if self.running:
|
|
64
|
-
self.running = False
|
|
65
|
-
self.total += time.time() - self.start
|
|
66
|
-
return self
|
|
67
|
-
|
|
68
|
-
def time(self):
|
|
69
|
-
if self.running:
|
|
70
|
-
return self.total + time.time() - self.start
|
|
71
|
-
return self.total
|
|
72
|
-
|
|
73
|
-
|
|
74
45
|
def dump_log(log_path, metrics=None, split=None, config=None):
|
|
75
46
|
"""Write log including the used items of config and the evaluation scores.
|
|
76
47
|
|
|
@@ -156,3 +127,17 @@ def is_multiclass_dataset(dataset, label="label"):
|
|
|
156
127
|
a multi-class problem."""
|
|
157
128
|
)
|
|
158
129
|
return ratio == 1.0
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def timer(func):
|
|
133
|
+
"""Log info-level wall time"""
|
|
134
|
+
|
|
135
|
+
@wraps(func)
|
|
136
|
+
def wrapper(*args, **kwargs):
|
|
137
|
+
start_time = time.time()
|
|
138
|
+
value = func(*args, **kwargs)
|
|
139
|
+
wall_time = time.time() - start_time
|
|
140
|
+
logging.info(f"{repr(func.__name__)} finished in {wall_time:.2f} seconds")
|
|
141
|
+
return value
|
|
142
|
+
|
|
143
|
+
return wrapper
|
|
@@ -64,6 +64,7 @@ class NDCG:
|
|
|
64
64
|
dcg = _DCG_argsort(argsort_preds, target, self.top_k)
|
|
65
65
|
idcg = _IDCG(target, self.top_k)
|
|
66
66
|
ndcg_score = dcg / idcg
|
|
67
|
+
# by convention, ndcg is 0 for zero label instances
|
|
67
68
|
self.score += np.nan_to_num(ndcg_score, nan=0.0).sum()
|
|
68
69
|
self.num_sample += argsort_preds.shape[0]
|
|
69
70
|
|
|
@@ -95,6 +96,7 @@ class RPrecision:
|
|
|
95
96
|
def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
|
|
96
97
|
top_k_idx = argsort_preds[:, -self.top_k :]
|
|
97
98
|
num_relevant = np.take_along_axis(target, top_k_idx, axis=-1).sum(axis=-1) # (batch_size, )
|
|
99
|
+
# by convention, rprecision is 0 for zero label instances
|
|
98
100
|
self.score += np.nan_to_num(num_relevant / np.minimum(self.top_k, target.sum(axis=-1)), nan=0.0).sum()
|
|
99
101
|
self.num_sample += argsort_preds.shape[0]
|
|
100
102
|
|
|
@@ -167,7 +169,8 @@ class Recall:
|
|
|
167
169
|
def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
|
|
168
170
|
top_k_idx = argsort_preds[:, -self.top_k :]
|
|
169
171
|
num_relevant = np.take_along_axis(target, top_k_idx, -1).sum(axis=-1)
|
|
170
|
-
|
|
172
|
+
# by convention, recall is 0 for zero label instances
|
|
173
|
+
self.score += np.nan_to_num(num_relevant / target.sum(axis=-1), nan=0.0).sum()
|
|
171
174
|
self.num_sample += argsort_preds.shape[0]
|
|
172
175
|
|
|
173
176
|
def compute(self) -> float:
|
|
@@ -210,14 +213,15 @@ class F1:
|
|
|
210
213
|
def compute(self) -> float:
|
|
211
214
|
prev_settings = np.seterr("ignore")
|
|
212
215
|
|
|
216
|
+
# F1 is 0 for the cases where there are no positive instances
|
|
213
217
|
if self.average == "macro":
|
|
214
218
|
score = np.nansum(2 * self.tp / (2 * self.tp + self.fp + self.fn)) / self.num_classes
|
|
215
219
|
elif self.average == "micro":
|
|
216
|
-
score = np.nan_to_num(2 * np.sum(self.tp) / np.sum(2 * self.tp + self.fp + self.fn))
|
|
220
|
+
score = np.nan_to_num(2 * np.sum(self.tp) / np.sum(2 * self.tp + self.fp + self.fn), nan=0.0)
|
|
217
221
|
elif self.average == "another-macro":
|
|
218
222
|
macro_prec = np.nansum(self.tp / (self.tp + self.fp)) / self.num_classes
|
|
219
223
|
macro_recall = np.nansum(self.tp / (self.tp + self.fn)) / self.num_classes
|
|
220
|
-
score = np.nan_to_num(2 * macro_prec * macro_recall / (macro_prec + macro_recall))
|
|
224
|
+
score = np.nan_to_num(2 * macro_prec * macro_recall / (macro_prec + macro_recall), nan=0.0)
|
|
221
225
|
|
|
222
226
|
np.seterr(**prev_settings)
|
|
223
227
|
return score
|
|
@@ -2,7 +2,6 @@ import csv
|
|
|
2
2
|
import gc
|
|
3
3
|
import logging
|
|
4
4
|
import warnings
|
|
5
|
-
from concurrent.futures import ProcessPoolExecutor
|
|
6
5
|
|
|
7
6
|
import pandas as pd
|
|
8
7
|
import torch
|
|
@@ -159,7 +158,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
|
|
|
159
158
|
This is effective only when is_test=False. Defaults to False.
|
|
160
159
|
|
|
161
160
|
Returns:
|
|
162
|
-
|
|
161
|
+
list[dict]: [{(optional: "index": ..., )"label": ..., "text": ...}, ...]
|
|
163
162
|
"""
|
|
164
163
|
assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe."
|
|
165
164
|
if isinstance(data, str):
|
|
@@ -176,9 +175,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
|
|
|
176
175
|
|
|
177
176
|
data["label"] = data["label"].astype(str).map(lambda s: s.split())
|
|
178
177
|
if tokenize_text:
|
|
179
|
-
|
|
180
|
-
with ProcessPoolExecutor() as executor:
|
|
181
|
-
data["text"] = pd.Series(tqdm(executor.map(tokenize, data["text"]), total=len(data["text"])))
|
|
178
|
+
data["text"] = data["text"].map(tokenize)
|
|
182
179
|
data = data.to_dict("records")
|
|
183
180
|
if not is_test:
|
|
184
181
|
num_no_label_data = sum(1 for d in data if len(d["label"]) == 0)
|
|
@@ -222,15 +219,12 @@ def load_datasets(
|
|
|
222
219
|
Returns:
|
|
223
220
|
dict: A dictionary of datasets.
|
|
224
221
|
"""
|
|
225
|
-
if
|
|
226
|
-
|
|
227
|
-
elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
|
|
228
|
-
assert (
|
|
229
|
-
not training_data.empty or not test_data.empty
|
|
230
|
-
), "At least one of `training_data` and `test_data` must be specified."
|
|
222
|
+
if training_data is None and test_data is None:
|
|
223
|
+
raise ValueError("At least one of `training_data` and `test_data` must be specified.")
|
|
231
224
|
|
|
232
225
|
datasets = {}
|
|
233
226
|
if training_data is not None:
|
|
227
|
+
logging.info(f"Loading training data")
|
|
234
228
|
datasets["train"] = _load_raw_data(
|
|
235
229
|
training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
|
|
236
230
|
)
|
|
@@ -243,11 +237,12 @@ def load_datasets(
|
|
|
243
237
|
datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42)
|
|
244
238
|
|
|
245
239
|
if test_data is not None:
|
|
240
|
+
logging.info(f"Loading test data")
|
|
246
241
|
datasets["test"] = _load_raw_data(
|
|
247
242
|
test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
|
|
248
243
|
)
|
|
249
244
|
|
|
250
|
-
if merge_train_val:
|
|
245
|
+
if merge_train_val and "val" in datasets:
|
|
251
246
|
datasets["train"] = datasets["train"] + datasets["val"]
|
|
252
247
|
for i in range(len(datasets["train"])):
|
|
253
248
|
datasets["train"][i]["index"] = i
|
|
@@ -217,19 +217,17 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
|
|
|
217
217
|
|
|
218
218
|
if match_top_k:
|
|
219
219
|
metric_abbr = match_top_k.group(1) # P, R, RP, or nDCG
|
|
220
|
-
|
|
221
|
-
if
|
|
222
|
-
raise ValueError(
|
|
223
|
-
f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})."
|
|
224
|
-
)
|
|
220
|
+
k = int(match_top_k.group(2))
|
|
221
|
+
if k >= num_classes:
|
|
222
|
+
raise ValueError(f"Invalid metric: {metric}. k ({k}) is greater than num_classes({num_classes}).")
|
|
225
223
|
if metric_abbr == "P":
|
|
226
|
-
metrics[metric] = Precision(num_classes, average="samples", top_k=
|
|
224
|
+
metrics[metric] = Precision(num_classes, average="samples", top_k=k)
|
|
227
225
|
elif metric_abbr == "R":
|
|
228
|
-
metrics[metric] = Recall(num_classes, average="samples", top_k=
|
|
226
|
+
metrics[metric] = Recall(num_classes, average="samples", top_k=k)
|
|
229
227
|
elif metric_abbr == "RP":
|
|
230
|
-
metrics[metric] = RPrecision(top_k=
|
|
228
|
+
metrics[metric] = RPrecision(top_k=k)
|
|
231
229
|
elif metric_abbr == "nDCG":
|
|
232
|
-
metrics[metric] = NDCG(top_k=
|
|
230
|
+
metrics[metric] = NDCG(top_k=k)
|
|
233
231
|
# The implementation in torchmetrics stores the prediction/target of all batches,
|
|
234
232
|
# which can lead to CUDA out of memory.
|
|
235
233
|
# metrics[metric] = RetrievalNormalizedDCG(k=top_k)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|