mlquantify 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,611 @@
+ import os
+ import random
+ from typing import Dict, Any, Sequence
+
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ try:
+     import torch
+     import torch.nn as nn
+     from torch.nn import MSELoss
+     from torch.nn.functional import relu
+ except ImportError:
+     import warnings
+     warnings.warn("torch is not installed. Neural network quantifiers will not work.")
+
+ from mlquantify.base import BaseQuantifier
+ from mlquantify.base_aggregative import (
+     AggregationMixin,
+     SoftLearnerQMixin,
+     get_aggregation_requirements,
+     _get_learner_function
+ )
+ from mlquantify.utils import (
+     validate_y,
+     validate_data,
+     check_classes_attribute,
+ )
+ from mlquantify.utils._validation import validate_prevalences
+ from mlquantify.model_selection import UPP
+ from mlquantify.utils import get_prev_from_labels
+ from mlquantify.utils._constraints import Interval, Options
+ from mlquantify.utils import _fit_context
+
+ from mlquantify.adjust_counting import CC, AC, PCC, PAC
+ from mlquantify.likelihood import EMQ
+
+ EPS = 1e-12
+
+
+
+ class EarlyStop:
+     """
+     A class implementing the early-stopping condition typically used for training neural networks.
+     An instance of this class is callable, and is to be used as follows:
+
+     >>> earlystop = EarlyStop(patience=2, lower_is_better=True)
+     >>> earlystop(0.9, epoch=0)
+     >>> earlystop(0.7, epoch=1)
+     >>> earlystop.IMPROVED  # is True
+     >>> earlystop(1.0, epoch=2)
+     >>> earlystop.STOP  # is False (patience=1)
+     >>> earlystop(1.0, epoch=3)
+     >>> earlystop.STOP  # is True (patience=0)
+     >>> earlystop.best_epoch  # is 1
+     >>> earlystop.best_score  # is 0.7
+
+     :param patience: the number of (consecutive) times that a monitored evaluation metric (typically obtained on a
+         held-out validation split) can be found to be worse than the best one obtained so far, before flagging the
+         stopping condition.
+     :param lower_is_better: if True (default) the metric is to be minimized.
+     :ivar best_score: keeps track of the best value seen so far
+     :ivar best_epoch: keeps track of the epoch in which the best score was set
+     :ivar STOP: flag (boolean) indicating the stopping condition
+     :ivar IMPROVED: flag (boolean) indicating whether there was an improvement in the last call
+     """
+
+     def __init__(self, patience, lower_is_better=True):
+
+         self.PATIENCE_LIMIT = patience
+         self.better = lambda a, b: a < b if lower_is_better else a > b
+         self.patience = patience
+         self.best_score = None
+         self.best_epoch = None
+         self.STOP = False
+         self.IMPROVED = False
+
+     def __call__(self, watch_score, epoch):
+         """
+         Commits the new score found in epoch `epoch`. If the score improves over the best score found so far, then
+         the patience counter gets reset. Otherwise, the patience counter is decreased, and in case it reaches 0,
+         the flag STOP becomes True.
+
+         :param watch_score: the new score
+         :param epoch: the current epoch
+         """
+         self.IMPROVED = (self.best_score is None or self.better(watch_score, self.best_score))
+         if self.IMPROVED:
+             self.best_score = watch_score
+             self.best_epoch = epoch
+             self.patience = self.PATIENCE_LIMIT
+         else:
+             self.patience -= 1
+             if self.patience <= 0:
+                 self.STOP = True
+
+
+
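As an aside to this diff: a minimal sketch (not part of the package) of how EarlyStop is meant to drive a training loop; `evaluate()` is a hypothetical function returning a validation loss.

    earlystop = EarlyStop(patience=10, lower_is_better=True)
    for epoch in range(100):
        val_loss = evaluate()       # hypothetical: compute validation loss
        earlystop(val_loss, epoch)  # commit the score for this epoch
        if earlystop.IMPROVED:
            pass                    # e.g., save a checkpoint here
        elif earlystop.STOP:
            break                   # patience exhausted: stop training
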
+ class QuaNetModule(nn.Module):
+     r"""
+     PyTorch module implementing the forward pass of QuaNet, as described in
+     Esuli et al. (2018), "A Recurrent Neural Network for Sentiment Quantification".
+
+     This module takes as input:
+     - the document embeddings of a bag,
+     - the posterior probabilities for each document in the bag,
+     - a fixed-size vector of quantification statistics (e.g., CC/ACC/PCC/PACC outputs),
+
+     and outputs a class-prevalence vector for the bag.
+
+     Core idea:
+     - Concatenate document embeddings and posterior probabilities.
+     - Sort the sequence by the posterior probability of a selected class (optional).
+     - Pass the sequence through an LSTM (possibly bidirectional).
+     - Take the final hidden state(s) as a "quantification embedding".
+     - Concatenate this embedding with the quantification statistics.
+     - Pass through one or more fully connected layers and a final softmax to obtain prevalences.
+     """
+
+     def __init__(
+         self,
+         doc_embedding_size: int,
+         n_classes: int,
+         stats_size: int,
+         lstm_hidden_size: int = 64,
+         lstm_nlayers: int = 1,
+         ff_layers: Sequence[int] = (1024, 512),
+         bidirectional: bool = True,
+         qdrop_p: float = 0.5,
+         order_by: int | None = 0,
+     ) -> None:
+         """
+         Parameters
+         ----------
+         doc_embedding_size : int
+             Dimensionality of document embeddings (output of `learner.transform`).
+         n_classes : int
+             Number of classes of the quantification problem.
+         stats_size : int
+             Dimensionality of the statistics vector concatenated to the LSTM embedding
+             (e.g. concatenated prevalence estimates from CC, ACC, PCC, PACC, EMQ, ...).
+         lstm_hidden_size : int, default=64
+             Hidden size of the LSTM cell(s).
+         lstm_nlayers : int, default=1
+             Number of stacked LSTM layers.
+         ff_layers : sequence of int, default=(1024, 512)
+             Sizes of the fully connected layers on top of the quantification embedding.
+         bidirectional : bool, default=True
+             Whether to use a bidirectional LSTM.
+         qdrop_p : float, default=0.5
+             Dropout probability used in the LSTM and in the fully connected layers.
+         order_by : int or None, default=0
+             Index of the class whose posterior probability is used for sorting the sequence.
+             If None, no sorting is performed.
+         """
+         super().__init__()
+
+         self.n_classes = n_classes
+         self.order_by = order_by
+         self.hidden_size = lstm_hidden_size
+         self.nlayers = lstm_nlayers
+         self.bidirectional = bidirectional
+         self.ndirections = 2 if self.bidirectional else 1
+         self.qdrop_p = qdrop_p
+         self.lstm = torch.nn.LSTM(doc_embedding_size + n_classes,  # +n_classes stands for the posterior probs. (concatenated)
+                                   lstm_hidden_size, lstm_nlayers, bidirectional=bidirectional,
+                                   dropout=qdrop_p, batch_first=True)
+         self.dropout = torch.nn.Dropout(self.qdrop_p)
+
+         lstm_output_size = self.hidden_size * self.ndirections
+         ff_input_size = lstm_output_size + stats_size
+         prev_size = ff_input_size
+         self.ff_layers = torch.nn.ModuleList()
+         for lin_size in ff_layers:
+             self.ff_layers.append(torch.nn.Linear(prev_size, lin_size))
+             prev_size = lin_size
+         self.output = torch.nn.Linear(prev_size, n_classes)
+
+     @property
+     def device(self) -> torch.device:
+         """Return the device on which the module parameters are stored."""
+         return next(self.parameters()).device
+
+     def _init_hidden(self) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Initialize LSTM hidden and cell states with zeros.
+
+         Returns
+         -------
+         (h0, c0) : (Tensor, Tensor)
+             Initial hidden and cell states.
+         """
+         directions = 2 if self.bidirectional else 1
+         var_hidden = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
+         var_cell = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
+         if next(self.lstm.parameters()).is_cuda:
+             var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
+         return var_hidden, var_cell
+
+     def forward(
+         self,
+         doc_embeddings: np.ndarray | torch.Tensor,
+         doc_posteriors: np.ndarray | torch.Tensor,
+         statistics: np.ndarray | torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass of QuaNet.
+
+         Parameters
+         ----------
+         doc_embeddings : array-like of shape (n_docs, emb_dim)
+             Document embeddings of all items in the bag.
+         doc_posteriors : array-like of shape (n_docs, n_classes)
+             Posterior probabilities `P(y | x)` for each document in the bag, produced by the base classifier.
+         statistics : array-like of shape (stats_size,) or (1, stats_size)
+             Vector of quantification-related statistics (e.g., CC/ACC/PCC/PACC estimates, TPR/FPR, etc.).
+
+         Returns
+         -------
+         prevalence : torch.Tensor of shape (1, n_classes)
+             Estimated class-prevalence vector for the input bag.
+         """
+         device = self.device
+         doc_embeddings = torch.as_tensor(doc_embeddings, dtype=torch.float, device=device)
+         doc_posteriors = torch.as_tensor(doc_posteriors, dtype=torch.float, device=device)
+         statistics = torch.as_tensor(statistics, dtype=torch.float, device=device)
+
+         if self.order_by is not None:
+             order = torch.argsort(doc_posteriors[:, self.order_by])
+             doc_embeddings = doc_embeddings[order]
+             doc_posteriors = doc_posteriors[order]
+
+         embedded_posteriors = torch.cat((doc_embeddings, doc_posteriors), dim=-1)
+
+         # the entire bag represents a single instance for QuaNet, so batch_size=1;
+         # the shape should be (1, number-of-instances, embedding-size + n_classes)
+         embedded_posteriors = embedded_posteriors.unsqueeze(0)
+
+         self.lstm.flatten_parameters()
+         _, (rnn_hidden, _) = self.lstm(embedded_posteriors, self._init_hidden())
+         rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
+         quant_embedding = rnn_hidden[0].view(-1)
+         quant_embedding = torch.cat((quant_embedding, statistics))
+
+         abstracted = quant_embedding.unsqueeze(0)
+
+         for linear in self.ff_layers:
+             abstracted = self.dropout(relu(linear(abstracted)))
+
+         logits = self.output(abstracted).view(1, -1)
+         prevalence = torch.softmax(logits, -1)
+
+         return prevalence
+
+
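A quick shape-check sketch (illustrative only; all sizes below are arbitrary): with a 32-dimensional embedding, 2 classes, and a length-10 statistics vector, the module maps a bag of any size to a (1, n_classes) prevalence vector that sums to 1.

    import numpy as np
    import torch

    module = QuaNetModule(doc_embedding_size=32, n_classes=2, stats_size=10,
                          lstm_hidden_size=16, ff_layers=(64,))
    emb = np.random.rand(50, 32)    # 50 documents, 32-dim embeddings
    post = np.random.rand(50, 2)    # posterior probabilities per document
    post = post / post.sum(axis=1, keepdims=True)
    stats = np.random.rand(10)      # e.g. CC/ACC/PCC/PACC/EMQ estimates, flattened
    prev = module(emb, post, stats)
    assert prev.shape == (1, 2) and torch.allclose(prev.sum(), torch.tensor(1.0))
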
+ class QuaNet(SoftLearnerQMixin, AggregationMixin, BaseQuantifier):
+     r"""
+     QuaNet: a deep quantification method following the QuaNet architecture,
+     implemented in the `mlquantify` style.
+
+     This class wraps a base probabilistic learner that:
+     - can be trained on labeled instances,
+     - can output posterior probabilities via `predict_proba(X)`,
+     - can transform instances into embeddings via `transform(X)`.
+
+     QuaNet then learns a mapping from bags of instances to class-prevalence vectors by:
+     - generating artificial bags with varying prevalences via the UPP sampling protocol,
+     - computing simple quantification estimates (CC, ACC, PCC, PACC, ...) on each bag,
+     - feeding the sequence of (embedding, posterior) pairs and the statistics vector into an LSTM-based network,
+     - minimizing a bag-level quantification loss (MSE between predicted and true prevalences).
+
+     Parameters
+     ----------
+     learner : estimator
+         Base probabilistic classifier. Must implement:
+         - fit(X, y),
+         - predict_proba(X) -> array-like (n_samples, n_classes),
+         - transform(X) -> array-like (n_samples, emb_dim).
+     fit_learner : bool, default=True
+         If True, the learner is trained inside QuaNet.fit.
+         If False, it is assumed to be already fitted.
+     sample_size : int, default=100
+         Bag size used by the sampling protocol during QuaNet training.
+     n_epochs : int, default=100
+         Maximum number of QuaNet training epochs.
+     tr_iter : int, default=500
+         Number of sampled bags (training iterations) per epoch.
+     va_iter : int, default=100
+         Number of sampled bags (validation iterations) per epoch.
+     lr : float, default=1e-3
+         Learning rate for the Adam optimizer.
+     lstm_hidden_size : int, default=64
+         Hidden size of the QuaNet LSTM.
+     lstm_nlayers : int, default=1
+         Number of layers in the QuaNet LSTM.
+     ff_layers : sequence of int, default=(1024, 512)
+         Sizes of the fully connected layers on top of the LSTM quantification embedding.
+     bidirectional : bool, default=True
+         Whether to use a bidirectional LSTM.
+     random_state : int or None, default=None
+         Seed used for the train/validation splits and for validation sampling.
+     qdrop_p : float, default=0.5
+         Dropout probability used in QuaNet.
+     patience : int, default=10
+         Early-stopping patience in number of epochs without validation improvement.
+     checkpointdir : str, default="./checkpoint_quanet"
+         Directory where intermediate QuaNet weights are stored.
+     checkpointname : str or None, default=None
+         Name of the saved checkpoint file. If None, a random name is generated.
+     device : {"cpu", "cuda"}, default="cuda"
+         Device on which to run the QuaNet model.
+     """
+
+     _parameter_constraints = {
+         "fit_learner": ["boolean"],
+         "sample_size": [Interval(0, None, inclusive_left=False), Options([None])],
+         "n_epochs": [Interval(0, None, inclusive_left=False), Options([None])],
+         "tr_iter": [Interval(0, None, inclusive_left=False), Options([None])],
+         "va_iter": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lr": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lstm_hidden_size": [Interval(0, None, inclusive_left=False), Options([None])],
+         "lstm_nlayers": [Interval(0, None, inclusive_left=False), Options([None])],
+         "bidirectional": ["boolean"],
+         "qdrop_p": [Interval(0, None, inclusive_left=False), Options([None])],
+         "patience": [Interval(0, None, inclusive_left=False), Options([None])],
+         "checkpointdir": ["string"],
+         "checkpointname": ["string", Options([None])],
+     }
+
+
+     def __init__(
+         self,
+         learner,
+         fit_learner: bool = True,
+         sample_size: int = 100,
+         n_epochs: int = 100,
+         tr_iter: int = 500,
+         va_iter: int = 100,
+         lr: float = 1e-3,
+         lstm_hidden_size: int = 64,
+         lstm_nlayers: int = 1,
+         ff_layers: Sequence[int] = (1024, 512),
+         bidirectional: bool = True,
+         random_state: int | None = None,
+         qdrop_p: float = 0.5,
+         patience: int = 10,
+         checkpointdir: str = "./checkpoint_quanet",
+         checkpointname: str | None = None,
+         device: str = "cuda",
+     ) -> None:
+
+         assert hasattr(learner, "transform"), "the learner must implement transform(X)"
+         assert hasattr(learner, "predict_proba"), "the learner must implement predict_proba(X)"
+
+         # save hyperparameters as attributes
+         self.learner = learner
+         self.fit_learner = fit_learner
+         self.sample_size = sample_size
+         self.n_epochs = n_epochs
+         self.tr_iter = tr_iter
+         self.va_iter = va_iter
+         self.lr = lr
+         self.lstm_hidden_size = lstm_hidden_size
+         self.lstm_nlayers = lstm_nlayers
+         self.ff_layers = ff_layers
+         self.bidirectional = bidirectional
+         self.random_state = random_state
+         self.qdrop_p = qdrop_p
+         self.patience = patience
+         self.checkpointdir = checkpointdir
+         self.checkpointname = checkpointname
+         self.device = torch.device(device)
+
+         self.quanet_params: Dict[str, Any] = dict(
+             lstm_hidden_size=lstm_hidden_size,
+             lstm_nlayers=lstm_nlayers,
+             ff_layers=ff_layers,
+             bidirectional=bidirectional,
+             qdrop_p=qdrop_p,
+         )
+
+         os.makedirs(self.checkpointdir, exist_ok=True)
+         if self.checkpointname is None:
+             local_random = random.Random()
+             random_code = "-".join(str(local_random.randint(0, 1_000_000)) for _ in range(5))
+             self.checkpointname = f"QuaNet-{random_code}"
+         self.checkpoint = os.path.join(self.checkpointdir, self.checkpointname)
+
+         self._classes_ = None
+         self.quantifiers = {}
+         self.quanet = None
+         self.optim = None
+
+         self.status: Dict[str, float] = {
+             "tr-loss": -1.0,
+             "va-loss": -1.0,
+             "tr-mae": -1.0,
+             "va-mae": -1.0,
+         }
+
+     @_fit_context(prefer_skip_nested_validation=True)
+     def fit(self, X, y):
+         y = validate_data(self, y=y)
+         self.classes_ = check_classes_attribute(self, np.unique(y))
+
+         os.makedirs(self.checkpointdir, exist_ok=True)
+
+         if self.fit_learner:
+             X_clf, X_rest, y_clf, y_rest = train_test_split(
+                 X, y, test_size=0.4, random_state=self.random_state, stratify=y
+             )
+             X_train, X_val, y_train, y_val = train_test_split(
+                 X_rest, y_rest, test_size=0.2, random_state=self.random_state, stratify=y_rest
+             )
+
+             self.learner.fit(X_clf, y_clf)
+         else:
+             X_train, X_val, y_train, y_val = train_test_split(
+                 X, y, test_size=0.40, random_state=self.random_state, stratify=y
+             )
+
+         self.tr_prev = get_prev_from_labels(y, format="array")
+
+         # obtain the document embeddings: QuaNet consumes embeddings, not raw features
+         X_train_embeddings = self.learner.transform(X_train)
+         X_val_embeddings = self.learner.transform(X_val)
+
+         valid_posteriors = self.learner.predict_proba(X_val)
+         train_posteriors = self.learner.predict_proba(X_train)
+
+         self.val_posteriors = valid_posteriors
+         self.y_val = y_val
+
+         self.quantifiers = {
+             "cc": CC(self.learner),
+             "acc": AC(self.learner),
+             "pcc": PCC(self.learner),
+             "pacc": PAC(self.learner),
+             "emq": EMQ(self.learner),
+         }
+
+         self.status = {
+             "tr-loss": -1.0,
+             "va-loss": -1.0,
+             "tr-mae": -1.0,
+             "va-mae": -1.0,
+         }
+
+         numQtf = len(self.quantifiers)
+         numClasses = len(self.classes_)
+
+         # the network's input size is the embedding dimensionality, not the original feature space
+         self.quanet = QuaNetModule(
+             doc_embedding_size=X_train_embeddings.shape[1],
+             n_classes=numClasses,
+             stats_size=numQtf * numClasses,
+             order_by=0 if numClasses == 2 else None,
+             **self.quanet_params
+         ).to(self.device)
+         print(self.quanet)
+
+         self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
+         early_stop = EarlyStop(
+             patience=self.patience,
+             lower_is_better=True,
+         )
+
+         checkpoint = self.checkpoint
+
+         for epoch in range(self.n_epochs):
+             # training and validation epochs operate on the embeddings, not on the original X
+             self._epoch(
+                 X_train_embeddings, y_train, train_posteriors,
+                 self.tr_iter, epoch, early_stop, train=True
+             )
+             self._epoch(
+                 X_val_embeddings, y_val, valid_posteriors,
+                 self.va_iter, epoch, early_stop, train=False
+             )
+
+             early_stop(self.status["va-loss"], epoch)
+             if early_stop.IMPROVED:
+                 torch.save(self.quanet.state_dict(), checkpoint)
+             elif early_stop.STOP:
+                 print(f'Training ended at epoch {early_stop.best_epoch}, loading best model parameters in {checkpoint}')
+                 self.quanet.load_state_dict(torch.load(checkpoint))
+                 break
+
+         return self
+
+     def _aggregate_qtf(self, posteriors, train_posteriors, y_train):
+         qtf_estims = []
+
+         for name, qtf in self.quantifiers.items():
+
+             requirements = get_aggregation_requirements(qtf)
+
+             if requirements.requires_train_proba and requirements.requires_train_labels:
+                 prev = qtf.aggregate(posteriors, train_posteriors, y_train)
+             elif requirements.requires_train_labels:
+                 prev = qtf.aggregate(posteriors, y_train)
+             else:
+                 prev = qtf.aggregate(posteriors)
+
+             qtf_estims.extend(np.asarray(list(prev.values())))
+
+         return qtf_estims
+
+
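For orientation, a sketch of the statistics vector this method assembles (the prevalence values below are hypothetical): with the five quantifiers above and two classes, the result is a flat list of 5 × 2 = 10 estimates, matching the stats_size passed to QuaNetModule.

    # hypothetical per-quantifier prevalence estimates for one bag
    estimates = {
        "cc":   {0: 0.30, 1: 0.70},
        "acc":  {0: 0.28, 1: 0.72},
        "pcc":  {0: 0.33, 1: 0.67},
        "pacc": {0: 0.31, 1: 0.69},
        "emq":  {0: 0.29, 1: 0.71},
    }
    qtf_estims = []
    for prev in estimates.values():
        qtf_estims.extend(list(prev.values()))
    assert len(qtf_estims) == 5 * 2  # == stats_size
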
+     def predict(self, X):
+         """Estimate the class-prevalence vector of the bag X."""
+         learner_function = _get_learner_function(self)
+         posteriors = getattr(self.learner, learner_function)(X)
+         embeddings = self.learner.transform(X)
+
+         qtf_estims = self._aggregate_qtf(posteriors, self.val_posteriors, self.y_val)
+
+         self.quanet.eval()
+         with torch.no_grad():
+             prevalence = self.quanet.forward(embeddings, posteriors, qtf_estims)
+             if self.device.type == "cuda":
+                 prevalence = prevalence.cpu()
+             prevalence = prevalence.numpy().flatten()
+
+         return prevalence
+
+     def _epoch(self, X, y, posteriors, iterations, epoch, early_stop, train: bool) -> None:
+         mse_loss = MSELoss()
+
+         self.quanet.train(mode=train)
+         losses = []
+         mae_errors = []
+
+         sampler = UPP(
+             batch_size=self.sample_size,
+             n_prevalences=iterations,
+             random_state=None if train else self.random_state,
+         )
+
+         for idx in sampler.split(X, y):
+             X_batch = X[idx]
+             y_batch = y[idx]
+             posteriors_batch = posteriors[idx]
+
+             qtf_estims = self._aggregate_qtf(posteriors_batch, self.val_posteriors, self.y_val)
+
+             p_true = torch.as_tensor(
+                 get_prev_from_labels(y_batch, format="array", classes=self.classes_),
+                 dtype=torch.float,
+                 device=self.device
+             ).unsqueeze(0)
+
+             if train:
+                 self.optim.zero_grad()
+                 p_pred = self.quanet.forward(
+                     X_batch,
+                     posteriors_batch,
+                     qtf_estims
+                 )
+                 loss = mse_loss(p_pred, p_true)
+                 mae = mae_loss(p_pred, p_true)
+                 loss.backward()
+                 self.optim.step()
+             else:
+                 with torch.no_grad():
+                     p_pred = self.quanet.forward(
+                         X_batch,
+                         posteriors_batch,
+                         qtf_estims
+                     )
+                     loss = mse_loss(p_pred, p_true)
+                     mae = mae_loss(p_pred, p_true)
+
+             losses.append(loss.item())
+             mae_errors.append(mae.item())
+
+         mae = np.mean(mae_errors)
+         mse = np.mean(losses)
+
+         if train:
+             self.status["tr-mae"] = mae
+             self.status["tr-loss"] = mse
+         else:
+             self.status["va-mae"] = mae
+             self.status["va-loss"] = mse
+
+     def _check_params_collision(self, quanet_params, learner_params):
+         quanet_keys = set(quanet_params.keys())
+         learner_keys = set(learner_params.keys())
+
+         collision_keys = quanet_keys.intersection(learner_keys)
+
+         if collision_keys:
+             raise ValueError(f"Parameters {collision_keys} are present in both quanet_params and learner_params")
+
+     def clean_checkpoint(self):
+         if os.path.exists(self.checkpoint):
+             os.remove(self.checkpoint)
+
+     def clean_checkpoint_dir(self):
+         import shutil
+         shutil.rmtree(self.checkpointdir, ignore_errors=True)
+
+
+ def mae_loss(y_true, y_pred):
+     """Mean absolute error between two prevalence vectors."""
+     return torch.mean(torch.abs(y_true - y_pred))
+
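A minimal end-to-end sketch of the intended call pattern (not part of the package): the EmbeddingClassifier wrapper below is hypothetical, the data is synthetic, and the tiny hyperparameters exist only to make the example cheap to run; whether the wrapped quantifiers (EMQ and friends) accept such a learner depends on mlquantify internals not shown in this diff.

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression

    class EmbeddingClassifier:
        """Hypothetical learner exposing fit / predict_proba / transform, as QuaNet requires."""
        def __init__(self):
            self.embedder = PCA(n_components=8)
            self.clf = LogisticRegression()
        def fit(self, X, y):
            self.clf.fit(self.embedder.fit_transform(X), y)
            return self
        def transform(self, X):
            return self.embedder.transform(X)
        def predict_proba(self, X):
            return self.clf.predict_proba(self.transform(X))

    X = np.random.rand(2000, 20)
    y = (X[:, 0] + np.random.rand(2000) > 1.0).astype(int)
    quanet = QuaNet(EmbeddingClassifier(), device="cpu",
                    n_epochs=3, tr_iter=20, va_iter=10, sample_size=50)
    quanet.fit(X, y)
    print(quanet.predict(X))  # estimated class prevalences for the bag X
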
@@ -43,5 +43,6 @@ from mlquantify.utils._validation import (
      _is_arraylike_not_scalar,
      _is_arraylike,
      validate_data,
-     check_classes_attribute
+     check_classes_attribute,
+     validate_prevalences
  )
@@ -198,6 +198,8 @@ class StringConstraint:
              return sp.issparse(value)
          if self.keyword == "boolean":
              return isinstance(value, bool)
+         if self.keyword == "string":
+             return isinstance(value, str)
          if self.keyword == "random_state":
              return isinstance(value, (np.random.RandomState, int, type(None)))
          if self.keyword == "nan":
@@ -266,6 +266,13 @@ def _is_arraylike(x):
      return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")


+ def _transform_if_float(y):
+     """Convert float labels to strings, so they are treated as categorical classes."""
+     if np.issubdtype(y.dtype, np.floating):
+         y = y.astype(str)
+     return y
+
+
  def validate_data(quantifier,
                    X="no_validation",
                    y="no_validation",
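A quick illustration of what `_transform_if_float` does (the label values are arbitrary):

    import numpy as np
    y = np.array([0.0, 1.0, 1.0])
    y = y.astype(str) if np.issubdtype(y.dtype, np.floating) else y
    # y is now array(['0.0', '1.0', '1.0'], ...): float labels become categorical strings
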
@@ -316,8 +323,10 @@ def validate_data(quantifier,
          if "estimator" not in check_y_params:
              check_y_params = {**default_check_params, **check_y_params}
          y = check_array(y, input_name="y", **check_y_params)
+         y = _transform_if_float(y)
      else:
          X, y = check_X_y(X, y, dtype=None, **check_params)
+         y = _transform_if_float(y)
      out = X, y

      return out