mlquantify 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlquantify/__init__.py +2 -1
- mlquantify/adjust_counting/__init__.py +6 -5
- mlquantify/adjust_counting/_adjustment.py +208 -37
- mlquantify/adjust_counting/_base.py +5 -6
- mlquantify/adjust_counting/_counting.py +10 -7
- mlquantify/likelihood/__init__.py +0 -2
- mlquantify/likelihood/_classes.py +45 -199
- mlquantify/meta/_classes.py +12 -12
- mlquantify/mixture/__init__.py +2 -1
- mlquantify/mixture/_classes.py +310 -15
- mlquantify/model_selection/_search.py +1 -1
- mlquantify/neighbors/_base.py +15 -15
- mlquantify/neighbors/_classes.py +2 -2
- mlquantify/neighbors/_kde.py +6 -6
- mlquantify/neural/__init__.py +1 -1
- mlquantify/neural/_base.py +0 -0
- mlquantify/neural/_classes.py +611 -0
- mlquantify/neural/_perm_invariant.py +0 -0
- mlquantify/neural/_utils.py +0 -0
- mlquantify/utils/__init__.py +2 -1
- mlquantify/utils/_constraints.py +2 -0
- mlquantify/utils/_validation.py +9 -0
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/METADATA +13 -18
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/RECORD +27 -23
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/WHEEL +1 -1
- mlquantify-0.1.22.dist-info/licenses/LICENSE +28 -0
- mlquantify/likelihood/_base.py +0 -147
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,611 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from typing import Dict, Any, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from sklearn.model_selection import train_test_split
|
|
7
|
+
try:
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torch.nn import MSELoss
|
|
11
|
+
from torch.nn.functional import relu
|
|
12
|
+
except ImportError:
|
|
13
|
+
import warnings
|
|
14
|
+
warnings.warn("torch is not installed. Neural network quantifiers will not work.")
|
|
15
|
+
|
|
16
|
+
from mlquantify.base import BaseQuantifier
|
|
17
|
+
from mlquantify.base_aggregative import (
|
|
18
|
+
AggregationMixin,
|
|
19
|
+
SoftLearnerQMixin,
|
|
20
|
+
get_aggregation_requirements,
|
|
21
|
+
_get_learner_function
|
|
22
|
+
)
|
|
23
|
+
from mlquantify.utils import (
|
|
24
|
+
validate_y,
|
|
25
|
+
validate_data,
|
|
26
|
+
check_classes_attribute,
|
|
27
|
+
)
|
|
28
|
+
from mlquantify.utils._validation import validate_prevalences
|
|
29
|
+
from mlquantify.model_selection import UPP
|
|
30
|
+
from mlquantify.utils import get_prev_from_labels
|
|
31
|
+
from mlquantify.utils._constraints import Interval, Options
|
|
32
|
+
from mlquantify.utils import _fit_context
|
|
33
|
+
|
|
34
|
+
from mlquantify.adjust_counting import CC, AC, PCC, PAC
|
|
35
|
+
from mlquantify.likelihood import EMQ
|
|
36
|
+
|
|
37
|
+
EPS = 1e-12
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class EarlyStop:
    """
    Early-stopping monitor for iterative (e.g. neural-network) training.

    An instance is callable: commit a new validation score each epoch and
    inspect the flags afterwards.

    >>> earlystop = EarlyStop(patience=2, lower_is_better=True)
    >>> earlystop(0.9, epoch=0)
    >>> earlystop(0.7, epoch=1)
    >>> earlystop.IMPROVED  # True
    >>> earlystop(1.0, epoch=2)
    >>> earlystop.STOP      # False (patience=1 remaining)
    >>> earlystop(1.0, epoch=3)
    >>> earlystop.STOP      # True (patience exhausted)
    >>> earlystop.best_epoch  # 1
    >>> earlystop.best_score  # 0.7

    :param patience: number of consecutive epochs in which the monitored metric
        (typically obtained on a held-out validation split) may fail to improve
        on the best value seen so far before the stopping condition is flagged.
    :param lower_is_better: if True (default) the metric is to be minimized.
    :ivar best_score: best value seen so far
    :ivar best_epoch: epoch in which the best score was set
    :ivar STOP: boolean flag indicating the stopping condition
    :ivar IMPROVED: boolean flag indicating whether the last call improved the best score
    """

    def __init__(self, patience, lower_is_better=True):
        # comparison used to decide whether a candidate beats the incumbent best
        if lower_is_better:
            self.better = lambda candidate, incumbent: candidate < incumbent
        else:
            self.better = lambda candidate, incumbent: candidate > incumbent
        self.PATIENCE_LIMIT = patience
        self.patience = patience
        self.best_score = None
        self.best_epoch = None
        self.STOP = False
        self.IMPROVED = False

    def __call__(self, watch_score, epoch):
        """
        Commit the score observed at epoch `epoch`.

        If the score improves on the best score found so far the patience
        counter is reset; otherwise it is decreased, and once it reaches 0
        the STOP flag becomes True.

        :param watch_score: the new score
        :param epoch: the current epoch
        """
        improved = self.best_score is None or self.better(watch_score, self.best_score)
        self.IMPROVED = improved
        if improved:
            self.best_score = watch_score
            self.best_epoch = epoch
            self.patience = self.PATIENCE_LIMIT
            return
        self.patience -= 1
        if self.patience <= 0:
            self.STOP = True
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class QuaNetModule(nn.Module):
    r"""
    PyTorch module implementing the forward pass of QuaNet, as described in
    Esuli et al. (2018), "A Recurrent Neural Network for Sentiment Quantification".

    This module takes as input:
    - the document embeddings of a bag,
    - the posterior probabilities for each document in the bag,
    - a fixed-size vector of quantification statistics (e.g., CC/ACC/PCC/PACC outputs),

    and outputs a class-prevalence vector for the bag.

    Core idea:
    - Concatenate document embeddings and posterior probabilities.
    - Sort the sequence by the posterior probability of a selected class (optional).
    - Pass the sequence through an LSTM (possibly bidirectional).
    - Take the final hidden state(s) as a "quantification embedding".
    - Concatenate this embedding with the quantification statistics.
    - Pass through one or more fully connected layers and a final softmax to obtain prevalences.
    """

    def __init__(
        self,
        doc_embedding_size: int,
        n_classes: int,
        stats_size: int,
        lstm_hidden_size: int = 64,
        lstm_nlayers: int = 1,
        ff_layers: Sequence[int] = (1024, 512),
        bidirectional: bool = True,
        qdrop_p: float = 0.5,
        order_by: int | None = 0,
    ) -> None:
        """
        Parameters
        ----------
        doc_embedding_size : int
            Dimensionality of document embeddings (output of `learner.transform`).
        n_classes : int
            Number of classes of the quantification problem.
        stats_size : int
            Dimensionality of the statistics vector concatenated to the LSTM embedding
            (e.g. concatenated prevalence estimates from CC, ACC, PCC, PACC, EMQ, ...).
        lstm_hidden_size : int, default=64
            Hidden size of the LSTM cell(s).
        lstm_nlayers : int, default=1
            Number of stacked LSTM layers.
        ff_layers : sequence of int, default=(1024, 512)
            Sizes of the fully connected layers on top of the quantification embedding.
        bidirectional : bool, default=True
            Whether to use a bidirectional LSTM.
        qdrop_p : float, default=0.5
            Dropout probability used in the LSTM and in the fully connected layers.
        order_by : int or None, default=0
            Index of the class whose posterior probability is used for sorting the sequence.
            If None, no sorting is performed.
        """
        super().__init__()

        self.n_classes = n_classes
        self.order_by = order_by
        self.hidden_size = lstm_hidden_size
        self.nlayers = lstm_nlayers
        self.bidirectional = bidirectional
        self.ndirections = 2 if self.bidirectional else 1
        self.qdrop_p = qdrop_p
        # +n_classes on the input size accounts for the posterior probabilities,
        # which are concatenated to each document embedding before the LSTM.
        # NOTE(review): torch warns when dropout>0 and lstm_nlayers==1 (dropout
        # only applies between stacked layers) — harmless, but worth confirming.
        self.lstm = torch.nn.LSTM(doc_embedding_size + n_classes,  # +n_classes stands for the posterior probs. (concatenated)
                                  lstm_hidden_size, lstm_nlayers, bidirectional=bidirectional,
                                  dropout=qdrop_p, batch_first=True)
        self.dropout = torch.nn.Dropout(self.qdrop_p)

        # Feed-forward head: LSTM embedding concatenated with the statistics vector.
        lstm_output_size = self.hidden_size * self.ndirections
        ff_input_size = lstm_output_size + stats_size
        prev_size = ff_input_size
        self.ff_layers = torch.nn.ModuleList()
        for lin_size in ff_layers:
            self.ff_layers.append(torch.nn.Linear(prev_size, lin_size))
            prev_size = lin_size
        self.output = torch.nn.Linear(prev_size, n_classes)

    @property
    def device(self) -> torch.device:
        """Return the device on which the module parameters are stored."""
        return next(self.parameters()).device

    def _init_hidden(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Initialize LSTM hidden and cell states with zeros.

        Returns
        -------
        (h0, c0) : (Tensor, Tensor)
            Initial hidden and cell states, shaped (nlayers * ndirections, 1, hidden_size)
            since the whole bag is processed as a single batch element.
        """
        directions = 2 if self.bidirectional else 1
        var_hidden = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
        var_cell = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
        # NOTE(review): device handling relies on .is_cuda/.cuda(); non-CUDA
        # accelerators (e.g. MPS) would not be covered — confirm intended scope.
        if next(self.lstm.parameters()).is_cuda:
            var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
        return var_hidden, var_cell

    def forward(
        self,
        doc_embeddings: np.ndarray | torch.Tensor,
        doc_posteriors: np.ndarray | torch.Tensor,
        statistics: np.ndarray | torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of QuaNet.

        Parameters
        ----------
        doc_embeddings : array-like of shape (n_docs, emb_dim)
            Document embeddings of all items in the bag.
        doc_posteriors : array-like of shape (n_docs, n_classes)
            Posterior probabilities `P(y | x)` for each document in the bag, produced by the base classifier.
        statistics : array-like of shape (stats_size,) or (1, stats_size)
            Vector of quantification-related statistics (e.g., CC/ACC/PCC/PACC estimates, TPR/FPR, etc.).

        Returns
        -------
        prevalence : torch.Tensor of shape (1, n_classes)
            Estimated class-prevalence vector for the input bag.
        """
        device = self.device
        doc_embeddings = torch.as_tensor(doc_embeddings, dtype=torch.float, device=device)
        doc_posteriors = torch.as_tensor(doc_posteriors, dtype=torch.float, device=device)
        statistics = torch.as_tensor(statistics, dtype=torch.float, device=device)

        # Optionally sort the bag by the posterior of one class so the LSTM sees
        # a monotone sequence of classifier confidence.
        if self.order_by is not None:
            order = torch.argsort(doc_posteriors[:, self.order_by])
            doc_embeddings = doc_embeddings[order]
            doc_posteriors = doc_posteriors[order]

        embeded_posteriors = torch.cat((doc_embeddings, doc_posteriors), dim=-1)

        # The entire bag is a single "instance" in quantification terms, so the
        # batch size is 1: shape (1, n_docs, emb_dim + n_classes).
        embeded_posteriors = embeded_posteriors.unsqueeze(0)

        self.lstm.flatten_parameters()
        _, (rnn_hidden, _) = self.lstm(embeded_posteriors, self._init_hidden())
        rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
        # NOTE(review): index 0 selects the hidden state of the first stacked
        # layer (both directions); this mirrors the reference QuaNet code but
        # with nlayers > 1 one might expect the last layer — confirm intent.
        quant_embedding = rnn_hidden[0].view(-1)
        quant_embedding = torch.cat((quant_embedding, statistics))

        abstracted = quant_embedding.unsqueeze(0)

        for linear in self.ff_layers:
            abstracted = self.dropout(relu(linear(abstracted)))

        logits = self.output(abstracted).view(1, -1)
        prevalence = torch.softmax(logits, -1)

        return prevalence
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class QuaNet(SoftLearnerQMixin, AggregationMixin, BaseQuantifier):
    r"""
    QuaNet: a deep quantification method following the QuaNet architecture,
    implemented in the `mlquantify` style.

    This class wraps a base probabilistic learner that:
    - can be trained on labeled instances,
    - can output posterior probabilities via `predict_proba(X)`,
    - can transform instances into embeddings via `transform(X)`.

    QuaNet then learns a mapping from bags of instances to class-prevalence vectors by:
    - generating artificial bags with the UPP sampling protocol,
    - computing simple quantification estimates (CC, AC, PCC, PAC, EMQ) on each bag,
    - feeding the sequence of (embedding, posterior) pairs and the statistics vector
      into an LSTM-based network (`QuaNetModule`),
    - minimizing a bag-level quantification loss (MSE between predicted and true prevalences).

    Parameters
    ----------
    learner : estimator
        Base probabilistic classifier. Must implement:
        - fit(X, y),
        - predict_proba(X) -> array-like (n_samples, n_classes),
        - transform(X) -> array-like (n_samples, emb_dim).
    fit_learner : bool, default=True
        If True, the learner is trained inside `QuaNet.fit`.
        If False, it is assumed to be already fitted.
    sample_size : int, default=100
        Bag size used by the sampling protocol during QuaNet training.
    n_epochs : int, default=100
        Maximum number of QuaNet training epochs.
    tr_iter : int, default=500
        Number of sampled bags (training iterations) per epoch.
    va_iter : int, default=100
        Number of sampled bags (validation iterations) per epoch.
    lr : float, default=1e-3
        Learning rate for the Adam optimizer.
    lstm_hidden_size : int, default=64
        Hidden size of the QuaNet LSTM.
    lstm_nlayers : int, default=1
        Number of layers in the QuaNet LSTM.
    ff_layers : sequence of int, default=(1024, 512)
        Sizes of the fully connected layers on top of the LSTM quantification embedding.
    bidirectional : bool, default=True
        Whether to use a bidirectional LSTM.
    random_state : int or None, default=None
        Seed used for the train/validation splits and the validation sampler.
    qdrop_p : float, default=0.5
        Dropout probability used in QuaNet.
    patience : int, default=10
        Early-stopping patience in number of epochs without validation improvement.
    checkpointdir : str, default="./checkpoint_quanet"
        Directory where intermediate QuaNet weights are stored.
    checkpointname : str or None, default=None
        Name of the saved checkpoint file. If None, a random name is generated.
    device : {"cpu", "cuda"}, default="cuda"
        Device on which to run the QuaNet model.
    """

    # NOTE(review): several constraints below look copy-pasted — boolean
    # parameters (`fit_learner`, `bidirectional`) are constrained with a
    # numeric Interval, and `checkpointname` is constrained to "string"
    # although its default is None. Confirm against the validation machinery.
    _parameter_constraints = {
        "fit_learner": [Interval(0, None, inclusive_left=False), Options([None])],
        "sample_size": [Interval(0, None, inclusive_left=False), Options([None])],
        "n_epochs": [Interval(0, None, inclusive_left=False), Options([None])],
        "tr_iter": [Interval(0, None, inclusive_left=False), Options([None])],
        "va_iter": [Interval(0, None, inclusive_left=False), Options([None])],
        "lr": [Interval(0, None, inclusive_left=False), Options([None])],
        "lstm_hidden_size": [Interval(0, None, inclusive_left=False), Options([None])],
        "lstm_nlayers": [Interval(0, None, inclusive_left=False), Options([None])],
        "bidirectional": [Interval(0, None, inclusive_left=False), Options([None])],
        "qdrop_p": [Interval(0, None, inclusive_left=False), Options([None])],
        "patience": [Interval(0, None, inclusive_left=False), Options([None])],
        "checkpointdir": ["string"],
        "checkpointname": ["string"],
    }

    def __init__(
        self,
        learner,
        fit_learner: bool = True,
        sample_size: int = 100,
        n_epochs: int = 100,
        tr_iter: int = 500,
        va_iter: int = 100,
        lr: float = 1e-3,
        lstm_hidden_size: int = 64,
        lstm_nlayers: int = 1,
        ff_layers: Sequence[int] = (1024, 512),
        bidirectional: bool = True,
        random_state: int | None = None,
        qdrop_p: float = 0.5,
        patience: int = 10,
        checkpointdir: str = "./checkpoint_quanet",
        checkpointname: str | None = None,
        device: str = "cuda",
    ) -> None:

        # The learner must provide both embeddings and posteriors.
        # NOTE(review): the assertion messages are literal Ellipsis (`...`);
        # a descriptive message would aid debugging.
        assert hasattr(learner, "transform"), ...
        assert hasattr(learner, "predict_proba"), ...

        # save hyperparameters as attributes
        self.learner = learner
        self.fit_learner = fit_learner
        self.sample_size = sample_size
        self.n_epochs = n_epochs
        self.tr_iter = tr_iter
        self.va_iter = va_iter
        self.lr = lr
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_nlayers = lstm_nlayers
        self.ff_layers = ff_layers
        self.bidirectional = bidirectional
        self.random_state = random_state
        self.qdrop_p = qdrop_p
        self.patience = patience
        self.checkpointdir = checkpointdir
        self.checkpointname = checkpointname
        self.device = torch.device(device)

        # Architecture kwargs forwarded verbatim to QuaNetModule in fit().
        self.quanet_params: Dict[str, Any] = dict(
            lstm_hidden_size=lstm_hidden_size,
            lstm_nlayers=lstm_nlayers,
            ff_layers=ff_layers,
            bidirectional=bidirectional,
            qdrop_p=qdrop_p,
        )

        # Prepare the checkpoint location; a random name avoids collisions
        # between concurrently trained instances.
        os.makedirs(self.checkpointdir, exist_ok=True)
        if self.checkpointname is None:
            local_random = random.Random()
            random_code = "-".join(str(local_random.randint(0, 1_000_000)) for _ in range(5))
            self.checkpointname = f"QuaNet-{random_code}"
        self.checkpoint = os.path.join(self.checkpointdir, self.checkpointname)

        # State populated by fit().
        self._classes_ = None
        self.quantifiers = {}   # name -> aggregative quantifier providing bag statistics
        self.quanet = None      # the QuaNetModule instance
        self.optim = None       # Adam optimizer over quanet parameters

        # Last observed training/validation metrics (-1.0 = not yet computed).
        self.status: Dict[str, float] = {
            "tr-loss": -1.0,
            "va-loss": -1.0,
            "tr-mae": -1.0,
            "va-mae": -1.0,
        }

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        """Fit the base learner (optionally) and train the QuaNet network.

        Splits the data so that the classifier and the quantification network
        are trained on disjoint portions, then runs up to `n_epochs` epochs of
        bag-level training with early stopping on the validation loss.
        Returns `self`.
        """
        y = validate_data(self, y=y)
        self.classes_ = check_classes_attribute(self, np.unique(y))

        os.makedirs(self.checkpointdir, exist_ok=True)

        if self.fit_learner:
            # 60% for the classifier, the remaining 40% is further split into
            # QuaNet train/validation portions.
            X_clf, X_rest, y_clf, y_rest = train_test_split(
                X, y, test_size=0.4, random_state=self.random_state, stratify=y
            )
            X_train, X_val, y_train, y_val = train_test_split(
                X_rest, y_rest, test_size=0.2, random_state=self.random_state, stratify=y_rest
            )

            self.learner.fit(X_clf, y_clf)
        else:
            # Learner already fitted: all data goes to QuaNet train/validation.
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.40, random_state=self.random_state, stratify=y
            )

        self.tr_prev = get_prev_from_labels(y, format="array")

        # QuaNet consumes learner embeddings, not the raw feature space.
        X_train_embeddings = self.learner.transform(X_train)
        X_val_embeddings = self.learner.transform(X_val)

        valid_posteriors = self.learner.predict_proba(X_val)
        train_posteriors = self.learner.predict_proba(X_train)

        # Kept for predict()/_epoch(): the validation posteriors/labels serve
        # as the "training" inputs of the adjustment quantifiers below.
        self.val_posteriors = valid_posteriors
        self.y_val = y_val

        # Classical quantifiers whose per-bag estimates form the statistics
        # vector fed to the network.
        self.quantifiers = {
            "cc": CC(self.learner),
            "acc": AC(self.learner),
            "pcc": PCC(self.learner),
            "pacc": PAC(self.learner),
            "emq": EMQ(self.learner),
        }

        self.status = {
            "tr-loss": -1.0,
            "va-loss": -1.0,
            "tr-mae": -1.0,
            "va-mae": -1.0,
        }

        numQtf = len(self.quantifiers)
        numClasses = len(self.classes_)

        # The embedding dimensionality (not the original feature count) sizes
        # the network input; stats_size is one prevalence vector per quantifier.
        self.quanet = QuaNetModule(
            doc_embedding_size=X_train_embeddings.shape[1],
            n_classes=numClasses,
            stats_size=numQtf*numClasses,
            # sorting by class-0 posterior only makes sense in the binary case
            order_by=0 if numClasses == 2 else None,
            **self.quanet_params
        ).to(self.device)
        print(self.quanet)

        self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
        early_stop = EarlyStop(
            patience=self.patience,
            lower_is_better=True,
        )

        checkpoint = self.checkpoint

        for epoch in range(self.n_epochs):
            # One training pass and one validation pass over freshly sampled bags.
            self._epoch(
                X_train_embeddings, y_train, train_posteriors,
                self.tr_iter, epoch, early_stop, train=True
            )
            self._epoch(
                X_val_embeddings, y_val, valid_posteriors,
                self.va_iter, epoch, early_stop, train=False
            )

            early_stop(self.status["va-loss"], epoch)
            if early_stop.IMPROVED:
                torch.save(self.quanet.state_dict(), checkpoint)
            elif early_stop.STOP:
                print(f'Training ended at epoch {early_stop.best_epoch}, loading best model parameters in {checkpoint}')
                self.quanet.load_state_dict(torch.load(checkpoint))
                break

        return self

    def _aggregate_qtf(self, posteriors, train_posteriors, y_train):
        """Concatenate the prevalence estimates of every auxiliary quantifier.

        Each quantifier is invoked with exactly the inputs its aggregation
        requirements declare; the resulting prevalence vectors are flattened
        into a single statistics list of length n_quantifiers * n_classes.
        """
        qtf_estims = []

        for name, qtf in self.quantifiers.items():

            requirements = get_aggregation_requirements(qtf)

            if requirements.requires_train_proba and requirements.requires_train_labels:
                prev = qtf.aggregate(posteriors, train_posteriors, y_train)
            elif requirements.requires_train_labels:
                prev = qtf.aggregate(posteriors, y_train)
            else:
                prev = qtf.aggregate(posteriors)

            # prev is a mapping class -> prevalence; keep only the values.
            qtf_estims.extend(np.asarray(list(prev.values())))

        return qtf_estims

    def predict(self, X):
        """Estimate class prevalences for the bag `X`.

        Returns a flat numpy array of prevalences (one entry per class, in
        `self.classes_` order).
        """
        learner_function = _get_learner_function(self)
        posteriors = getattr(self.learner, learner_function)(X)
        embeddings = self.learner.transform(X)

        # Statistics vector built from the same held-out data used in training.
        qtf_estims = self._aggregate_qtf(posteriors, self.val_posteriors, self.y_val)

        self.quanet.eval()
        with torch.no_grad():
            prevalence = self.quanet.forward(embeddings, posteriors, qtf_estims)
            if self.device.type == "cuda":
                prevalence = prevalence.cpu()
            prevalence = prevalence.numpy().flatten()

        return prevalence

    def _epoch(self, X, y, posteriors, iterations, epoch, early_stop, train: bool) -> None:
        """Run one epoch of bag-level training or evaluation.

        Parameters
        ----------
        X : array-like
            Learner *embeddings* of the instances (callers pass the transformed
            data, not raw features).
        y : array-like
            Labels aligned with `X`.
        posteriors : array-like
            Precomputed posterior probabilities aligned with `X`.
        iterations : int
            Number of bags to sample.
        epoch : int
            Current epoch index (informational).
        early_stop : EarlyStop
            Early-stopping monitor (not updated here; updated by fit()).
        train : bool
            If True, update the network weights; otherwise evaluate only.
        """
        mse_loss = MSELoss()

        self.quanet.train(mode=train)
        losses = []
        mae_errors = []

        # Training bags are drawn at random each epoch; validation uses a fixed
        # seed so successive validation epochs see comparable bags.
        sampler = UPP(
            batch_size=self.sample_size,
            n_prevalences=iterations,
            random_state= None if train else self.random_state,
        )

        for idx in sampler.split(X, y):
            X_batch = X[idx]
            y_batch = y[idx]
            posteriors_batch = posteriors[idx]

            qtf_estims = self._aggregate_qtf(posteriors_batch, self.val_posteriors, self.y_val)

            p_true = torch.as_tensor(
                get_prev_from_labels(y_batch, format="array", classes=self.classes_),
                dtype=torch.float,
                device=self.device
            ).unsqueeze(0)

            if train:
                self.optim.zero_grad()
                p_pred = self.quanet.forward(
                    X_batch,
                    posteriors_batch,
                    qtf_estims
                )
                loss = mse_loss(p_pred, p_true)
                mae = mae_loss(p_pred, p_true)
                loss.backward()
                self.optim.step()
            else:
                with torch.no_grad():
                    p_pred = self.quanet.forward(
                        X_batch,
                        posteriors_batch,
                        qtf_estims
                    )
                    loss = mse_loss(p_pred, p_true)
                    mae = mae_loss(p_pred, p_true)

            losses.append(loss.item())
            mae_errors.append(mae.item())

        mae = np.mean(mae_errors)
        mse = np.mean(losses)

        if train:
            self.status["tr-mae"] = mae
            self.status["tr-loss"] = mse
        else:
            self.status["va-mae"] = mae
            self.status["va-loss"] = mse

    def _check_params_colision(self, quanet_params, learner_params):
        """Raise ValueError if the two parameter dicts share any key."""
        quanet_keys = set(quanet_params.keys())
        learner_keys = set(learner_params.keys())

        colision_keys = quanet_keys.intersection(learner_keys)

        if colision_keys:
            raise ValueError(f"Parameters {colision_keys} are present in both quanet_params and learner_params")

    def clean_checkpoint(self):
        """Delete this instance's checkpoint file, if it exists."""
        if os.path.exists(self.checkpoint):
            os.remove(self.checkpoint)

    def clean_checkpoint_dir(self):
        """Delete the whole checkpoint directory (best-effort)."""
        import shutil
        shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def mae_loss(y_true, y_pred):
    """Mean absolute error between two tensors of prevalences."""
    return (y_true - y_pred).abs().mean()
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
|
|
File without changes
|
|
File without changes
|
mlquantify/utils/__init__.py
CHANGED
mlquantify/utils/_constraints.py
CHANGED
|
@@ -198,6 +198,8 @@ class StringConstraint:
|
|
|
198
198
|
return sp.issparse(value)
|
|
199
199
|
if self.keyword == "boolean":
|
|
200
200
|
return isinstance(value, bool)
|
|
201
|
+
if self.keyword == "string":
|
|
202
|
+
return isinstance(value, str)
|
|
201
203
|
if self.keyword == "random_state":
|
|
202
204
|
return isinstance(value, (np.random.RandomState, int, type(None)))
|
|
203
205
|
if self.keyword == "nan":
|
mlquantify/utils/_validation.py
CHANGED
|
@@ -266,6 +266,13 @@ def _is_arraylike(x):
|
|
|
266
266
|
return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")
|
|
267
267
|
|
|
268
268
|
|
|
269
|
+
def _transform_if_float(y):
|
|
270
|
+
"""Transform y to integers if it is float."""
|
|
271
|
+
if np.issubdtype(y.dtype, np.floating):
|
|
272
|
+
y = y.astype(str)
|
|
273
|
+
return y
|
|
274
|
+
|
|
275
|
+
|
|
269
276
|
def validate_data(quantifier,
|
|
270
277
|
X="no_validation",
|
|
271
278
|
y="no_validation",
|
|
@@ -316,8 +323,10 @@ def validate_data(quantifier,
|
|
|
316
323
|
if "estimator" not in check_y_params:
|
|
317
324
|
check_y_params = {**default_check_params, **check_y_params}
|
|
318
325
|
y = check_array(y, input_name="y", **check_y_params)
|
|
326
|
+
y = _transform_if_float(y)
|
|
319
327
|
else:
|
|
320
328
|
X, y = check_X_y(X, y, dtype=None, **check_params)
|
|
329
|
+
y = _transform_if_float(y)
|
|
321
330
|
out = X, y
|
|
322
331
|
|
|
323
332
|
return out
|