libmultilabel 0.5.1__tar.gz → 0.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/PKG-INFO +1 -1
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/common_utils.py +15 -30
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/data_utils.py +7 -12
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/metrics.py +7 -9
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel.egg-info/PKG-INFO +1 -1
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/setup.cfg +1 -1
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/LICENSE +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/README.md +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/__init__.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/__init__.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/data_utils.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/linear.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/metrics.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/preprocessor.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/tree.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/linear/utils.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/logging.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/__init__.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/model.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/__init__.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/bert.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/bert_attention.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/caml.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/kim_cnn.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/labelwise_attention_networks.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/modules.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/networks/xml_cnn.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel/nn/nn_utils.py +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel.egg-info/SOURCES.txt +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel.egg-info/dependency_links.txt +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel.egg-info/requires.txt +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/libmultilabel.egg-info/top_level.txt +0 -0
- {libmultilabel-0.5.1 → libmultilabel-0.5.2}/pyproject.toml +0 -0
libmultilabel/common_utils.py

@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import time
+from functools import wraps

 import numpy as np

@@ -41,36 +42,6 @@ class AttributeDict(dict):
         return {k: self[k] for k in self._used}


-class Timer(object):
-    """Computes elasped time."""
-
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.running = True
-        self.total = 0
-        self.start = time.time()
-        return self
-
-    def resume(self):
-        if not self.running:
-            self.running = True
-            self.start = time.time()
-        return self
-
-    def stop(self):
-        if self.running:
-            self.running = False
-            self.total += time.time() - self.start
-        return self
-
-    def time(self):
-        if self.running:
-            return self.total + time.time() - self.start
-        return self.total
-
-
 def dump_log(log_path, metrics=None, split=None, config=None):
     """Write log including the used items of config and the evaluation scores.

@@ -156,3 +127,17 @@ def is_multiclass_dataset(dataset, label="label"):
             a multi-class problem."""
         )
     return ratio == 1.0
+
+
+def timer(func):
+    """Log info-level wall time"""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        value = func(*args, **kwargs)
+        wall_time = time.time() - start_time
+        logging.info(f"{repr(func.__name__)} finished in {wall_time:.2f} seconds")
+        return value
+
+    return wrapper
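Note: 0.5.2 replaces the stateful Timer class with the timer decorator added above. A minimal usage sketch, assuming logging is configured at INFO level (the function name train is hypothetical):

import logging
import time

from libmultilabel.common_utils import timer

logging.basicConfig(level=logging.INFO)

@timer
def train():
    time.sleep(1)  # stand-in for real work

train()  # logs: 'train' finished in 1.00 seconds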
libmultilabel/nn/data_utils.py

@@ -2,7 +2,6 @@ import csv
 import gc
 import logging
 import warnings
-from concurrent.futures import ProcessPoolExecutor

 import pandas as pd
 import torch
@@ -159,7 +158,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
         This is effective only when is_test=False. Defaults to False.

     Returns:
-
+        dict: [{(optional: "index": ..., )"label": ..., "text": ...}, ...]
     """
     assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe."
     if isinstance(data, str):
@@ -176,9 +175,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data

     data["label"] = data["label"].astype(str).map(lambda s: s.split())
     if tokenize_text:
-
-        with ProcessPoolExecutor() as executor:
-            data["text"] = pd.Series(tqdm(executor.map(tokenize, data["text"]), total=len(data["text"])))
+        data["text"] = data["text"].map(tokenize)
     data = data.to_dict("records")
     if not is_test:
         num_no_label_data = sum(1 for d in data if len(d["label"]) == 0)
@@ -222,15 +219,12 @@ def load_datasets(
     Returns:
         dict: A dictionary of datasets.
     """
-    if isinstance(training_data, str) or isinstance(test_data, str):
-        assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified."
-    elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
-        assert (
-            not training_data.empty or not test_data.empty
-        ), "At least one of `training_data` and `test_data` must be specified."
+    if training_data is None and test_data is None:
+        raise ValueError("At least one of `training_data` and `test_data` must be specified.")

     datasets = {}
     if training_data is not None:
+        logging.info(f"Loading training data")
         datasets["train"] = _load_raw_data(
             training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
         )
@@ -243,11 +237,12 @@ def load_datasets(
         datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42)

     if test_data is not None:
+        logging.info(f"Loading test data")
         datasets["test"] = _load_raw_data(
             test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
         )

-    if merge_train_val:
+    if merge_train_val and "val" in datasets:
         datasets["train"] = datasets["train"] + datasets["val"]
         for i in range(len(datasets["train"])):
             datasets["train"][i]["index"] = i
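Note: the load_datasets change trades type-dependent asserts for an explicit ValueError, so omitting both inputs now fails the same way for strings, DataFrames, and missing arguments; tokenization likewise drops the ProcessPoolExecutor in favor of a plain pandas map, and merge_train_val is skipped when no "val" split exists. A minimal sketch of the new failure mode:

from libmultilabel.nn.data_utils import load_datasets

try:
    load_datasets()  # neither training_data nor test_data given
except ValueError as e:
    print(e)  # At least one of `training_data` and `test_data` must be specified.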
libmultilabel/nn/metrics.py

@@ -217,19 +217,17 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):

         if match_top_k:
             metric_abbr = match_top_k.group(1)  # P, R, PR, or nDCG
-            top_k = int(match_top_k.group(2))
-            if top_k >= num_classes:
-                raise ValueError(
-                    f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})."
-                )
+            k = int(match_top_k.group(2))
+            if k >= num_classes:
+                raise ValueError(f"Invalid metric: {metric}. k ({k}) is greater than num_classes({num_classes}).")
             if metric_abbr == "P":
-                metrics[metric] = Precision(num_classes, average="samples", top_k=top_k)
+                metrics[metric] = Precision(num_classes, average="samples", top_k=k)
             elif metric_abbr == "R":
-                metrics[metric] = Recall(num_classes, average="samples", top_k=top_k)
+                metrics[metric] = Recall(num_classes, average="samples", top_k=k)
             elif metric_abbr == "RP":
-                metrics[metric] = RPrecision(top_k=top_k)
+                metrics[metric] = RPrecision(top_k=k)
             elif metric_abbr == "nDCG":
-                metrics[metric] = NDCG(top_k=top_k)
+                metrics[metric] = NDCG(top_k=k)
         # The implementation in torchmetrics stores the prediction/target of all batches,
         # which can lead to CUDA out of memory.
         # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
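Note: after this fix, the @k suffix of each monitored metric is parsed into a local k instead of clobbering the top_k parameter, so mixed cutoffs resolve independently. A minimal sketch using the signature from the hunk header (the num_classes value is illustrative):

from libmultilabel.nn.metrics import get_metrics

# Each metric's cutoff comes from its own name; k must be smaller than num_classes.
metrics = get_metrics(
    metric_threshold=0.5,
    monitor_metrics=["P@1", "P@5", "nDCG@10"],
    num_classes=100,
)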