mlquantify 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,609 @@
+ import os
+ import random
+ from typing import Dict, Any, Sequence
+
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ import torch
+ import torch.nn as nn
+ from torch.nn import MSELoss
+ from torch.nn.functional import relu
+
+ from tqdm import tqdm
+
+ from mlquantify.base import BaseQuantifier
+ from mlquantify.base_aggregative import (
+     AggregationMixin,
+     SoftLearnerQMixin,
+     get_aggregation_requirements,
+     _get_learner_function
+ )
+ from mlquantify.utils import (
+     validate_y,
+     validate_data,
+     check_classes_attribute,
+ )
+ from mlquantify.utils._validation import validate_prevalences
+ from mlquantify.model_selection import UPP
+ from mlquantify.utils import get_prev_from_labels
+ from mlquantify.utils._constraints import Interval, Options
+ from mlquantify.utils import _fit_context
+
+ from mlquantify.adjust_counting import CC, AC, PCC, PAC
+ from mlquantify.likelihood import EMQ
+
+ EPS = 1e-12
+
+
+ class EarlyStop:
+     """
+     A class implementing the early-stopping condition typically used for training neural networks.
+
+     >>> earlystop = EarlyStop(patience=2, lower_is_better=True)
+     >>> earlystop(0.9, epoch=0)
+     >>> earlystop(0.7, epoch=1)
+     >>> earlystop.IMPROVED  # is True
+     >>> earlystop(1.0, epoch=2)
+     >>> earlystop.STOP  # is False (patience=1)
+     >>> earlystop(1.0, epoch=3)
+     >>> earlystop.STOP  # is True (patience=0)
+     >>> earlystop.best_epoch  # is 1
+     >>> earlystop.best_score  # is 0.7
+
+     :param patience: the number of consecutive times that a monitored evaluation metric (typically obtained on a
+         held-out validation split) can be found to be worse than the best one obtained so far before flagging the
+         stopping condition. An instance of this class is callable and is to be used as in the example above.
+     :param lower_is_better: if True (default) the metric is to be minimized.
+     :ivar best_score: keeps track of the best value seen so far
+     :ivar best_epoch: keeps track of the epoch in which the best score was set
+     :ivar STOP: flag (boolean) indicating the stopping condition
+     :ivar IMPROVED: flag (boolean) indicating whether there was an improvement in the last call
+     """
+
+     def __init__(self, patience, lower_is_better=True):
+         self.PATIENCE_LIMIT = patience
+         self.better = lambda a, b: a < b if lower_is_better else a > b
+         self.patience = patience
+         self.best_score = None
+         self.best_epoch = None
+         self.STOP = False
+         self.IMPROVED = False
+
+     def __call__(self, watch_score, epoch):
+         """
+         Commits the new score found in epoch `epoch`. If the score improves over the best score found so far,
+         the patience counter is reset. Otherwise, the patience counter is decreased and, in case it reaches 0,
+         the flag STOP becomes True.
+
+         :param watch_score: the new score
+         :param epoch: the current epoch
+         """
+         self.IMPROVED = (self.best_score is None or self.better(watch_score, self.best_score))
+         if self.IMPROVED:
+             self.best_score = watch_score
+             self.best_epoch = epoch
+             self.patience = self.PATIENCE_LIMIT
+         else:
+             self.patience -= 1
+             if self.patience <= 0:
+                 self.STOP = True
+
+
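A minimal sketch (not part of the package) of how EarlyStop is meant to be driven from a training loop, here with a scripted sequence of validation losses; the same save-on-improve / stop-and-reload pattern appears in QuaNet.fit further below:

    earlystop = EarlyStop(patience=2, lower_is_better=True)
    for epoch, va_loss in enumerate([0.9, 0.7, 0.8, 0.8, 0.85]):
        earlystop(va_loss, epoch)
        if earlystop.IMPROVED:
            pass  # e.g., save a checkpoint of the current model here
        elif earlystop.STOP:
            break  # reload the best epoch's checkpoint and stop
    # ends at epoch 3 with best_epoch == 1 and best_score == 0.7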
+ class QuaNetModule(nn.Module):
+     r"""
+     PyTorch module implementing the forward pass of QuaNet, as described in
+     Esuli et al. (2018), "A Recurrent Neural Network for Sentiment Quantification".
+
+     This module takes as input:
+     - the document embeddings of a bag,
+     - the posterior probabilities for each document in the bag,
+     - a fixed-size vector of quantification statistics (e.g., CC/ACC/PCC/PACC outputs),
+
+     and outputs a class-prevalence vector for the bag.
+
+     Core idea:
+     - Concatenate document embeddings and posterior probabilities.
+     - Sort the sequence by the posterior probability of a selected class (optional).
+     - Pass the sequence through an LSTM (possibly bidirectional).
+     - Take the final hidden state(s) as a "quantification embedding".
+     - Concatenate this embedding with the quantification statistics.
+     - Pass through one or more fully connected layers and a final softmax to obtain prevalences.
+     """
+
+     def __init__(
+         self,
+         doc_embedding_size: int,
+         n_classes: int,
+         stats_size: int,
+         lstm_hidden_size: int = 64,
+         lstm_nlayers: int = 1,
+         ff_layers: Sequence[int] = (1024, 512),
+         bidirectional: bool = True,
+         qdrop_p: float = 0.5,
+         order_by: int | None = 0,
+     ) -> None:
+         """
+         Parameters
+         ----------
+         doc_embedding_size : int
+             Dimensionality of document embeddings (output of `learner.transform`).
+         n_classes : int
+             Number of classes of the quantification problem.
+         stats_size : int
+             Dimensionality of the statistics vector concatenated to the LSTM embedding
+             (e.g. concatenated prevalence estimates from CC, ACC, PCC, PACC, EMQ, ...).
+         lstm_hidden_size : int, default=64
+             Hidden size of the LSTM cell(s).
+         lstm_nlayers : int, default=1
+             Number of stacked LSTM layers.
+         ff_layers : sequence of int, default=(1024, 512)
+             Sizes of the fully connected layers on top of the quantification embedding.
+         bidirectional : bool, default=True
+             Whether to use a bidirectional LSTM.
+         qdrop_p : float, default=0.5
+             Dropout probability used in the LSTM and in the fully connected layers.
+         order_by : int or None, default=0
+             Index of the class whose posterior probability is used for sorting the sequence.
+             If None, no sorting is performed.
+         """
+         super().__init__()
+
+         self.n_classes = n_classes
+         self.order_by = order_by
+         self.hidden_size = lstm_hidden_size
+         self.nlayers = lstm_nlayers
+         self.bidirectional = bidirectional
+         self.ndirections = 2 if self.bidirectional else 1
+         self.qdrop_p = qdrop_p
+         self.lstm = torch.nn.LSTM(doc_embedding_size + n_classes,  # +n_classes stands for the posterior probs. (concatenated)
+                                   lstm_hidden_size, lstm_nlayers, bidirectional=bidirectional,
+                                   dropout=qdrop_p, batch_first=True)
+         self.dropout = torch.nn.Dropout(self.qdrop_p)
+
+         lstm_output_size = self.hidden_size * self.ndirections
+         ff_input_size = lstm_output_size + stats_size
+         prev_size = ff_input_size
+         self.ff_layers = torch.nn.ModuleList()
+         for lin_size in ff_layers:
+             self.ff_layers.append(torch.nn.Linear(prev_size, lin_size))
+             prev_size = lin_size
+         self.output = torch.nn.Linear(prev_size, n_classes)
+
+     @property
+     def device(self) -> torch.device:
+         """Return the device on which the module parameters are stored."""
+         return next(self.parameters()).device
+
+     def _init_hidden(self) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Initialize LSTM hidden and cell states with zeros.
+
+         Returns
+         -------
+         (h0, c0) : (Tensor, Tensor)
+             Initial hidden and cell states.
+         """
+         directions = 2 if self.bidirectional else 1
+         var_hidden = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
+         var_cell = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
+         if next(self.lstm.parameters()).is_cuda:
+             var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
+         return var_hidden, var_cell
+
+     def forward(
+         self,
+         doc_embeddings: np.ndarray | torch.Tensor,
+         doc_posteriors: np.ndarray | torch.Tensor,
+         statistics: np.ndarray | torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass of QuaNet.
+
+         Parameters
+         ----------
+         doc_embeddings : array-like of shape (n_docs, emb_dim)
+             Document embeddings of all items in the bag.
+         doc_posteriors : array-like of shape (n_docs, n_classes)
+             Posterior probabilities `P(y | x)` for each document in the bag, produced by the base classifier.
+         statistics : array-like of shape (stats_size,) or (1, stats_size)
+             Vector of quantification-related statistics (e.g., CC/ACC/PCC/PACC estimates, TPR/FPR, etc.).
+
+         Returns
+         -------
+         prevalence : torch.Tensor of shape (1, n_classes)
+             Estimated class-prevalence vector for the input bag.
+         """
+         device = self.device
+         doc_embeddings = torch.as_tensor(doc_embeddings, dtype=torch.float, device=device)
+         doc_posteriors = torch.as_tensor(doc_posteriors, dtype=torch.float, device=device)
+         statistics = torch.as_tensor(statistics, dtype=torch.float, device=device)
+
+         if self.order_by is not None:
+             order = torch.argsort(doc_posteriors[:, self.order_by])
+             doc_embeddings = doc_embeddings[order]
+             doc_posteriors = doc_posteriors[order]
+
+         embedded_posteriors = torch.cat((doc_embeddings, doc_posteriors), dim=-1)
+
+         # the entire bag constitutes a single instance in quantification contexts, so batch_size=1;
+         # the shape should be (1, number-of-instances, embedding-size + n_classes)
+         embedded_posteriors = embedded_posteriors.unsqueeze(0)
+
+         self.lstm.flatten_parameters()
+         _, (rnn_hidden, _) = self.lstm(embedded_posteriors, self._init_hidden())
+         rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
+         quant_embedding = rnn_hidden[0].view(-1)
+         quant_embedding = torch.cat((quant_embedding, statistics))
+
+         abstracted = quant_embedding.unsqueeze(0)
+
+         for linear in self.ff_layers:
+             abstracted = self.dropout(relu(linear(abstracted)))
+
+         logits = self.output(abstracted).view(1, -1)
+         prevalence = torch.softmax(logits, -1)
+
+         return prevalence
+
+
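To make the expected shapes concrete, here is a minimal sketch (not part of the package) that runs QuaNetModule on random inputs; the stats_size of 10 mimics the 5 quantifiers x 2 classes used by QuaNet below:

    import torch

    module = QuaNetModule(doc_embedding_size=32, n_classes=2, stats_size=10,
                          lstm_hidden_size=16, ff_layers=(64,))
    module.eval()  # disable dropout for a deterministic forward pass

    emb = torch.randn(50, 32)                         # 50 documents, 32-dim embeddings
    post = torch.softmax(torch.randn(50, 2), dim=-1)  # per-document posteriors
    stats = torch.rand(10)                            # quantifier estimates, flattened

    prev = module(emb, post, stats)  # tensor of shape (1, 2), summing to 1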
+ class QuaNet(SoftLearnerQMixin, AggregationMixin, BaseQuantifier):
+     r"""
+     QuaNet: a deep quantification method following the QuaNet architecture,
+     implemented in the `mlquantify` style.
+
+     This class wraps a base probabilistic learner that:
+     - can be trained on labeled instances,
+     - can output posterior probabilities via `predict_proba(X)`,
+     - can transform instances into embeddings via `transform(X)`.
+
+     QuaNet then learns a mapping from bags of instances to class-prevalence vectors by:
+     - generating artificial bags using the UPP protocol (uniform sampling of prevalences),
+     - computing simple quantification estimates (CC, ACC, PCC, PACC, ...) on each bag,
+     - feeding the sequence of (embedding, posterior) pairs and the statistics vector into an LSTM-based network,
+     - minimizing a bag-level quantification loss (MSE between predicted and true prevalences).
+
+     Parameters
+     ----------
+     learner : estimator
+         Base probabilistic classifier. Must implement:
+         - fit(X, y),
+         - predict_proba(X) -> array-like (n_samples, n_classes),
+         - transform(X) -> array-like (n_samples, emb_dim).
+     fit_learner : bool, default=True
+         If True, the learner is trained inside QuaNet.fit.
+         If False, it is assumed to be already fitted.
+     sample_size : int, default=100
+         Bag size used by the UPP protocol during QuaNet training.
+     n_epochs : int, default=100
+         Maximum number of QuaNet training epochs.
+     tr_iter : int, default=500
+         Number of UPP samplings (training iterations) per epoch.
+     va_iter : int, default=100
+         Number of UPP samplings (validation iterations) per epoch.
+     lr : float, default=1e-3
+         Learning rate for the Adam optimizer.
+     lstm_hidden_size : int, default=64
+         Hidden size of the QuaNet LSTM.
+     lstm_nlayers : int, default=1
+         Number of layers in the QuaNet LSTM.
+     ff_layers : sequence of int, default=(1024, 512)
+         Sizes of the fully connected layers on top of the LSTM quantification embedding.
+     bidirectional : bool, default=True
+         Whether to use a bidirectional LSTM.
+     random_state : int or None, default=None
+         Seed used for the train/validation splits and for validation-time sampling.
+     qdrop_p : float, default=0.5
+         Dropout probability used in QuaNet.
+     patience : int, default=10
+         Early-stopping patience in number of epochs without validation improvement.
+     checkpointdir : str, default="./checkpoint_quanet"
+         Directory where intermediate QuaNet weights are stored.
+     checkpointname : str or None, default=None
+         Name of the saved checkpoint file. If None, a random name is generated.
+     device : {"cpu", "cuda"}, default="cuda"
+         Device on which to run the QuaNet model.
+     """
+
+     _parameter_constraints = {
+         "fit_learner": ["boolean"],
+         "sample_size": [Interval(0, None, inclusive_left=False), Options([None])],
+         "n_epochs": [Interval(0, None, inclusive_left=False), Options([None])],
+         "tr_iter": [Interval(0, None, inclusive_left=False), Options([None])],
+         "va_iter": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lr": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lstm_hidden_size": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lstm_nlayers": [Interval(0, None, inclusive_left=False), Options([None])],
+         "bidirectional": ["boolean"],
+         "qdrop_p": [Interval(0, None, inclusive_left=False), Options([None])],
+         "patience": [Interval(0, None, inclusive_left=False), Options([None])],
+         "checkpointdir": ["string"],
+         "checkpointname": ["string"],
+     }
+
+
+     def __init__(
+         self,
+         learner,
+         fit_learner: bool = True,
+         sample_size: int = 100,
+         n_epochs: int = 100,
+         tr_iter: int = 500,
+         va_iter: int = 100,
+         lr: float = 1e-3,
+         lstm_hidden_size: int = 64,
+         lstm_nlayers: int = 1,
+         ff_layers: Sequence[int] = (1024, 512),
+         bidirectional: bool = True,
+         random_state: int | None = None,
+         qdrop_p: float = 0.5,
+         patience: int = 10,
+         checkpointdir: str = "./checkpoint_quanet",
+         checkpointname: str | None = None,
+         device: str = "cuda",
+     ) -> None:
+
+         assert hasattr(learner, "transform"), "the learner must implement transform(X)"
+         assert hasattr(learner, "predict_proba"), "the learner must implement predict_proba(X)"
+
+         # save hyperparameters as attributes
+         self.learner = learner
+         self.fit_learner = fit_learner
+         self.sample_size = sample_size
+         self.n_epochs = n_epochs
+         self.tr_iter = tr_iter
+         self.va_iter = va_iter
+         self.lr = lr
+         self.lstm_hidden_size = lstm_hidden_size
+         self.lstm_nlayers = lstm_nlayers
+         self.ff_layers = ff_layers
+         self.bidirectional = bidirectional
+         self.random_state = random_state
+         self.qdrop_p = qdrop_p
+         self.patience = patience
+         self.checkpointdir = checkpointdir
+         self.checkpointname = checkpointname
+         self.device = torch.device(device)
+
+         self.quanet_params: Dict[str, Any] = dict(
+             lstm_hidden_size=lstm_hidden_size,
+             lstm_nlayers=lstm_nlayers,
+             ff_layers=ff_layers,
+             bidirectional=bidirectional,
+             qdrop_p=qdrop_p,
+         )
+
+         os.makedirs(self.checkpointdir, exist_ok=True)
+         if self.checkpointname is None:
+             local_random = random.Random()
+             random_code = "-".join(str(local_random.randint(0, 1_000_000)) for _ in range(5))
+             self.checkpointname = f"QuaNet-{random_code}"
+         self.checkpoint = os.path.join(self.checkpointdir, self.checkpointname)
+
+         self._classes_ = None
+         self.quantifiers = {}
+         self.quanet = None
+         self.optim = None
+
+         self.status: Dict[str, float] = {
+             "tr-loss": -1.0,
+             "va-loss": -1.0,
+             "tr-mae": -1.0,
+             "va-mae": -1.0,
+         }
+
+     @_fit_context(prefer_skip_nested_validation=True)
+     def fit(self, X, y):
+         y = validate_data(self, y=y)
+         self.classes_ = check_classes_attribute(self, np.unique(y))
+
+         os.makedirs(self.checkpointdir, exist_ok=True)
+
+         if self.fit_learner:
+             # 60% to fit the classifier; the remaining 40% is split 80/20 into QuaNet train/validation
+             X_clf, X_rest, y_clf, y_rest = train_test_split(
+                 X, y, test_size=0.4, random_state=self.random_state, stratify=y
+             )
+             X_train, X_val, y_train, y_val = train_test_split(
+                 X_rest, y_rest, test_size=0.2, random_state=self.random_state, stratify=y_rest
+             )
+
+             self.learner.fit(X_clf, y_clf)
+         else:
+             X_train, X_val, y_train, y_val = train_test_split(
+                 X, y, test_size=0.4, random_state=self.random_state, stratify=y
+             )
+
+         self.tr_prev = get_prev_from_labels(y, format="array")
+
+         # obtain the document embeddings (QuaNet operates on embeddings, not on the original features)
+         X_train_embeddings = self.learner.transform(X_train)
+         X_val_embeddings = self.learner.transform(X_val)
+
+         valid_posteriors = self.learner.predict_proba(X_val)
+         train_posteriors = self.learner.predict_proba(X_train)
+
+         self.val_posteriors = valid_posteriors
+         self.y_val = y_val
+
+         self.quantifiers = {
+             "cc": CC(self.learner),
+             "acc": AC(self.learner),
+             "pcc": PCC(self.learner),
+             "pacc": PAC(self.learner),
+             "emq": EMQ(self.learner),
+         }
+
+         self.status = {
+             "tr-loss": -1.0,
+             "va-loss": -1.0,
+             "tr-mae": -1.0,
+             "va-mae": -1.0,
+         }
+
+         numQtf = len(self.quantifiers)
+         numClasses = len(self.classes_)
+
+         # use the embedding dimensionality, not that of the original features
+         self.quanet = QuaNetModule(
+             doc_embedding_size=X_train_embeddings.shape[1],
+             n_classes=numClasses,
+             stats_size=numQtf * numClasses,
+             order_by=0 if numClasses == 2 else None,
+             **self.quanet_params
+         ).to(self.device)
+         print(self.quanet)
+
+         self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
+         early_stop = EarlyStop(
+             patience=self.patience,
+             lower_is_better=True,
+         )
+
+         checkpoint = self.checkpoint
+
+         for epoch in range(self.n_epochs):
+             # train and validate on the embeddings rather than on the original X
+             self._epoch(
+                 X_train_embeddings, y_train, train_posteriors,
+                 self.tr_iter, epoch, early_stop, train=True
+             )
+             self._epoch(
+                 X_val_embeddings, y_val, valid_posteriors,
+                 self.va_iter, epoch, early_stop, train=False
+             )
+
+             early_stop(self.status["va-loss"], epoch)
+             if early_stop.IMPROVED:
+                 torch.save(self.quanet.state_dict(), checkpoint)
+             elif early_stop.STOP:
+                 print(f'Training ended at epoch {early_stop.best_epoch}, loading best model parameters in {checkpoint}')
+                 self.quanet.load_state_dict(torch.load(checkpoint))
+                 break
+
+         return self
+
+     def _aggregate_qtf(self, posteriors, train_posteriors, y_train):
+         qtf_estims = []
+
+         for name, qtf in self.quantifiers.items():
+
+             requirements = get_aggregation_requirements(qtf)
+
+             if requirements.requires_train_proba and requirements.requires_train_labels:
+                 prev = qtf.aggregate(posteriors, train_posteriors, y_train)
+             elif requirements.requires_train_labels:
+                 prev = qtf.aggregate(posteriors, y_train)
+             else:
+                 prev = qtf.aggregate(posteriors)
+
+             qtf_estims.extend(np.asarray(list(prev.values())))
+
+         return qtf_estims
+
+
+     def predict(self, X):
+
+         learner_function = _get_learner_function(self)
+         posteriors = getattr(self.learner, learner_function)(X)
+         embeddings = self.learner.transform(X)
+
+         qtf_estims = self._aggregate_qtf(posteriors, self.val_posteriors, self.y_val)
+
+         self.quanet.eval()
+         with torch.no_grad():
+             prevalence = self.quanet.forward(embeddings, posteriors, qtf_estims)
+             if self.device.type == "cuda":
+                 prevalence = prevalence.cpu()
+             prevalence = prevalence.numpy().flatten()
+
+         return prevalence
+
+
+     def _epoch(self, X, y, posteriors, iterations, epoch, early_stop, train: bool) -> None:
+         mse_loss = MSELoss()
+
+         self.quanet.train(mode=train)
+         losses = []
+         mae_errors = []
+
+         sampler = UPP(
+             batch_size=self.sample_size,
+             n_prevalences=iterations,
+             random_state=None if train else self.random_state,
+         )
+
+         for idx in sampler.split(X, y):
+             X_batch = X[idx]
+             y_batch = y[idx]
+             posteriors_batch = posteriors[idx]
+
+             qtf_estims = self._aggregate_qtf(posteriors_batch, self.val_posteriors, self.y_val)
+
+             p_true = torch.as_tensor(
+                 get_prev_from_labels(y_batch, format="array", classes=self.classes_),
+                 dtype=torch.float,
+                 device=self.device
+             ).unsqueeze(0)
+
+             if train:
+                 self.optim.zero_grad()
+                 p_pred = self.quanet.forward(
+                     X_batch,
+                     posteriors_batch,
+                     qtf_estims
+                 )
+                 loss = mse_loss(p_pred, p_true)
+                 mae = mae_loss(p_pred, p_true)
+                 loss.backward()
+                 self.optim.step()
+             else:
+                 with torch.no_grad():
+                     p_pred = self.quanet.forward(
+                         X_batch,
+                         posteriors_batch,
+                         qtf_estims
+                     )
+                     loss = mse_loss(p_pred, p_true)
+                     mae = mae_loss(p_pred, p_true)
+
+             losses.append(loss.item())
+             mae_errors.append(mae.item())
+
+         mae = np.mean(mae_errors)
+         mse = np.mean(losses)
+
+         if train:
+             self.status["tr-mae"] = mae
+             self.status["tr-loss"] = mse
+         else:
+             self.status["va-mae"] = mae
+             self.status["va-loss"] = mse
+
+     def _check_params_colision(self, quanet_params, learner_params):
+         quanet_keys = set(quanet_params.keys())
+         learner_keys = set(learner_params.keys())
+
+         colision_keys = quanet_keys.intersection(learner_keys)
+
+         if colision_keys:
+             raise ValueError(f"Parameters {colision_keys} are present in both quanet_params and learner_params")
+
+     def clean_checkpoint(self):
+         if os.path.exists(self.checkpoint):
+             os.remove(self.checkpoint)
+
+     def clean_checkpoint_dir(self):
+         import shutil
+         shutil.rmtree(self.checkpointdir, ignore_errors=True)
+
+
+ def mae_loss(y_true, y_pred):
+     return torch.mean(torch.abs(y_true - y_pred))
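For orientation, a minimal end-to-end sketch (not part of the package; the import path of QuaNet is an assumption, since the diff does not show the new file's name). The hypothetical wrapper below fakes the required transform method by returning the raw features as "embeddings"; a real learner would expose hidden representations:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    class EmbeddingLearner(LogisticRegression):  # hypothetical wrapper
        def transform(self, X):
            return np.asarray(X)  # identity "embeddings", for illustration only

    X = np.random.randn(1000, 32)
    y = np.random.randint(0, 2, size=1000)

    quanet = QuaNet(EmbeddingLearner(max_iter=1000), device="cpu",
                    n_epochs=3, tr_iter=20, va_iter=10, sample_size=50)
    quanet.fit(X, y)
    prevalences = quanet.predict(np.random.randn(200, 32))  # shape (n_classes,)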
File without changes
File without changes
@@ -43,5 +43,6 @@ from mlquantify.utils._validation import (
  _is_arraylike_not_scalar,
  _is_arraylike,
  validate_data,
- check_classes_attribute
+ check_classes_attribute,
+ validate_prevalences
  )
@@ -198,6 +198,8 @@ class StringConstraint:
      return sp.issparse(value)
  if self.keyword == "boolean":
      return isinstance(value, bool)
+ if self.keyword == "string":
+     return isinstance(value, str)
  if self.keyword == "random_state":
      return isinstance(value, (np.random.RandomState, int, type(None)))
  if self.keyword == "nan":
@@ -266,6 +266,13 @@ def _is_arraylike(x):
  return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")
 
 
+ def _transform_if_float(y):
+     """Transform y to strings if it is float."""
+     if np.issubdtype(y.dtype, np.floating):
+         y = y.astype(str)
+     return y
+
+
  def validate_data(quantifier,
                    X="no_validation",
                    y="no_validation",
@@ -316,8 +323,10 @@ def validate_data(quantifier,
  if "estimator" not in check_y_params:
      check_y_params = {**default_check_params, **check_y_params}
      y = check_array(y, input_name="y", **check_y_params)
+     y = _transform_if_float(y)
  else:
      X, y = check_X_y(X, y, dtype=None, **check_params)
+     y = _transform_if_float(y)
  out = X, y
 
  return out
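For reference, the effect of the new helper on float-typed labels (behavior as implemented above: float labels become string class names, not integers):

    import numpy as np

    y = np.array([0.0, 1.0, 1.0])
    if np.issubdtype(y.dtype, np.floating):
        y = y.astype(str)
    # y is now array(['0.0', '1.0', '1.0'])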