mlquantify 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -8,6 +8,7 @@ from mlquantify.utils._sampling import (
8
8
  simplex_uniform_kraemer,
9
9
  simplex_uniform_sampling,
10
10
  )
11
+ from mlquantify.utils._random import check_random_state
11
12
  from mlquantify.utils._validation import validate_data
12
13
  from abc import ABC, abstractmethod
13
14
  from logging import warning
@@ -170,6 +171,8 @@ class APP(BaseProtocol):
170
171
  def _iter_indices(self, X: np.ndarray, y: np.ndarray):
171
172
 
172
173
  n_dim = len(np.unique(y))
174
+
175
+ rng = check_random_state(self.random_state)
173
176
 
174
177
  for batch_size in self.batch_size:
175
178
  prevalences = simplex_grid_sampling(n_dim=n_dim,
@@ -178,9 +181,8 @@ class APP(BaseProtocol):
178
181
  min_val=self.min_prev,
179
182
  max_val=self.max_prev)
180
183
  for prev in prevalences:
181
- indexes = get_indexes_with_prevalence(y, prev, batch_size)
184
+ indexes = get_indexes_with_prevalence(y, prev, batch_size, random_state=rng)
182
185
  yield indexes
183
-
184
186
 
185
187
 
186
188
 
@@ -221,10 +223,10 @@ class NPP(BaseProtocol):
221
223
  self.repeats = repeats
222
224
 
223
225
  def _iter_indices(self, X: np.ndarray, y: np.ndarray):
224
-
226
+ rng = check_random_state(self.random_state)
225
227
  for _ in range(self.n_samples):
226
228
  for batch_size in self.batch_size:
227
- idx = np.random.choice(X.shape[0], batch_size, replace=True)
229
+ idx = rng.choice(X.shape[0], batch_size, replace=True)
228
230
  for _ in range(self.repeats):
229
231
  yield idx
230
232
 
@@ -289,6 +291,8 @@ class UPP(BaseProtocol):
289
291
  def _iter_indices(self, X: np.ndarray, y: np.ndarray):
290
292
 
291
293
  n_dim = len(np.unique(y))
294
+
295
+ rng = check_random_state(self.random_state)
292
296
 
293
297
  for batch_size in self.batch_size:
294
298
  if self.algorithm == 'kraemer':
@@ -296,16 +300,17 @@ class UPP(BaseProtocol):
296
300
  n_prev=self.n_prevalences,
297
301
  n_iter=self.repeats,
298
302
  min_val=self.min_prev,
299
- max_val=self.max_prev)
303
+ max_val=self.max_prev,
304
+ random_state=rng)
300
305
  elif self.algorithm == 'uniform':
301
306
  prevalences = simplex_uniform_sampling(n_dim=n_dim,
302
307
  n_prev=self.n_prevalences,
303
308
  n_iter=self.repeats,
304
309
  min_val=self.min_prev,
305
- max_val=self.max_prev)
306
-
310
+ max_val=self.max_prev,
311
+ random_state=rng)
307
312
  for prev in prevalences:
308
- indexes = get_indexes_with_prevalence(y, prev, batch_size)
313
+ indexes = get_indexes_with_prevalence(y, prev, batch_size, random_state=rng)
309
314
  yield indexes
310
315
 
311
316
 
@@ -347,12 +352,12 @@ class PPP(BaseProtocol):
347
352
  repeats=repeats)
348
353
 
349
354
  def _iter_indices(self, X: np.ndarray, y: np.ndarray):
350
-
355
+ rng = check_random_state(self.random_state)
351
356
  for batch_size in self.batch_size:
352
357
  for prev in self.prevalences:
353
358
  if isinstance(prev, float):
354
359
  prev = [1-prev, prev]
355
360
 
356
- indexes = get_indexes_with_prevalence(y, prev, batch_size)
361
+ indexes = get_indexes_with_prevalence(y, prev, batch_size, random_state=rng)
357
362
  yield indexes
358
363
 
@@ -1,8 +1,9 @@
1
1
  import numpy as np
2
+ from mlquantify.utils import check_random_state
2
3
  import itertools
3
4
 
4
5
 
5
- def get_indexes_with_prevalence(y, prevalence: list, sample_size:int):
6
+ def get_indexes_with_prevalence(y, prevalence: list, sample_size:int, random_state: int = None):
6
7
  """
7
8
  Get indexes for a stratified sample based on the prevalence of each class.
8
9
 
@@ -23,6 +24,7 @@ def get_indexes_with_prevalence(y, prevalence: list, sample_size:int):
23
24
  List of indexes for the stratified sample.
24
25
  """
25
26
  classes = np.unique(y)
27
+ rng = check_random_state(random_state)
26
28
 
27
29
  # Ensure the sum of prevalences is 1
28
30
  assert np.isclose(sum(prevalence), 1), "The sum of prevalences must be 1"
@@ -43,12 +45,12 @@ def get_indexes_with_prevalence(y, prevalence: list, sample_size:int):
43
45
  class_indexes = np.where(y == class_)[0]
44
46
 
45
47
  # Sample the indexes for the current class
46
- sampled_class_indexes = np.random.choice(class_indexes, size=num_samples, replace=True)
48
+ sampled_class_indexes = rng.choice(class_indexes, size=num_samples, replace=True)
47
49
 
48
50
  sampled_indexes.extend(sampled_class_indexes)
49
51
  total_sampled += num_samples
50
52
 
51
- np.random.shuffle(sampled_indexes) # Shuffle after collecting all indexes
53
+ rng.shuffle(sampled_indexes) # Shuffle after collecting all indexes
52
54
 
53
55
  return sampled_indexes
54
56
 
@@ -59,7 +61,8 @@ def simplex_uniform_kraemer(n_dim: int,
59
61
  n_iter: int,
60
62
  min_val: float = 0.0,
61
63
  max_val: float = 1.0,
62
- max_tries: int = 1000) -> np.ndarray:
64
+ max_tries: int = 1000,
65
+ random_state: int = None) -> np.ndarray:
63
66
  """
64
67
  Generates n_prev prevalence vectors of n_dim classes uniformly
65
68
  distributed on the simplex, with optional lower and upper bounds.
@@ -91,28 +94,25 @@ def simplex_uniform_kraemer(n_dim: int,
91
94
  if min_val * n_dim > 1 or max_val * n_dim < 1:
92
95
  raise ValueError("Invalid bounds: they make it impossible to sum to 1.")
93
96
 
97
+ rng = check_random_state(random_state)
98
+
94
99
  effective_simplex_size = 1 - n_dim * min_val
95
100
  prevs = []
96
101
 
97
- # Amostragem em blocos até atingir n_prev válidos
98
102
  tries = 0
99
- batch_size = max(n_prev, 1000) # Gera em blocos grandes para eficiência
103
+ batch_size = n_prev
100
104
 
101
105
  while len(prevs) < n_prev and tries < max_tries:
102
106
  tries += 1
103
-
104
- # Geração de pontos uniformes no simplex reduzido
105
- u = np.random.uniform(0, 1, (batch_size, n_dim - 1))
107
+
108
+ u = rng.uniform(0, 1, (batch_size, n_dim - 1))
106
109
  u.sort(axis=1)
107
110
  simplex = np.diff(np.concatenate([np.zeros((batch_size, 1)), u, np.ones((batch_size, 1))], axis=1), axis=1)
108
111
 
109
- # Escala para [min_val, max_val]
110
112
  scaled = min_val + simplex * effective_simplex_size
111
113
 
112
- # Normaliza para garantir soma = 1
113
114
  scaled /= scaled.sum(axis=1, keepdims=True)
114
115
 
115
- # Filtra apenas vetores válidos
116
116
  mask = np.all((scaled >= min_val) & (scaled <= max_val), axis=1)
117
117
  valid = scaled[mask]
118
118
 
@@ -122,11 +122,13 @@ def simplex_uniform_kraemer(n_dim: int,
122
122
  if not prevs:
123
123
  raise RuntimeError("No valid prevalences found with given constraints. Try adjusting min_val/max_val.")
124
124
 
125
- if n_iter > 1:
126
- prevs = np.tile(prevs, (n_iter, 1))
127
-
128
125
  result = np.vstack(prevs)
129
- return result[:n_prev]
126
+ result = result[:n_prev]
127
+
128
+ if n_iter > 1:
129
+ result = np.repeat(result, n_iter, axis=0)
130
+
131
+ return result
130
132
 
131
133
 
132
134
 
@@ -135,7 +137,7 @@ def simplex_grid_sampling(
135
137
  n_prev: int,
136
138
  n_iter: int,
137
139
  min_val: float,
138
- max_val: float
140
+ max_val: float,
139
141
  ) -> np.ndarray:
140
142
  """
141
143
  Efficiently generates artificial prevalence vectors that sum to 1
@@ -181,7 +183,7 @@ def simplex_grid_sampling(
181
183
 
182
184
  # Repetição se necessário
183
185
  if n_iter > 1:
184
- prevs = np.tile(prevs, (n_iter, 1))
186
+ prevs = np.repeat(prevs, n_iter, axis=0)
185
187
 
186
188
  return prevs
187
189
 
@@ -193,7 +195,8 @@ def simplex_uniform_sampling(
193
195
  n_prev: int,
194
196
  n_iter: int,
195
197
  min_val: float,
196
- max_val: float
198
+ max_val: float,
199
+ random_state: int = None
197
200
  ) -> np.ndarray:
198
201
  """
199
202
  Generates uniformly distributed prevalence vectors within the simplex,
@@ -265,9 +268,8 @@ def bootstrap_sample_indices(
265
268
  np.ndarray
266
269
  Array containing indices for a bootstrap sample.
267
270
  """
268
- if random_state is not None:
269
- np.random.seed(random_state)
271
+ rng = check_random_state(random_state)
270
272
 
271
273
  for _ in range(n_bootstraps):
272
- indices = np.random.choice(n_samples, size=batch_size, replace=True)
274
+ indices = rng.choice(n_samples, size=batch_size, replace=True)
273
275
  yield indices
@@ -3,27 +3,36 @@ import pandas as pd
3
3
  from collections import defaultdict
4
4
 
5
5
 
6
- def get_prev_from_labels(y, format="dict") -> dict:
6
+ def get_prev_from_labels(y, format="dict", classes: list = None):
7
7
  """
8
8
  Get the real prevalence of each class in the target array.
9
-
9
+
10
10
  Parameters
11
11
  ----------
12
12
  y : np.ndarray or pd.Series
13
13
  Array of class labels.
14
-
14
+ format : str, default="dict"
15
+ Format of the output. Can be "array" or "dict".
16
+ classes : list, optional
17
+ List of unique classes. If provided, the output will be sorted by these classes.
18
+
15
19
  Returns
16
20
  -------
17
- dict
18
- Dictionary of class labels and their corresponding prevalence.
21
+ dict or np.ndarray
22
+ Dictionary of class labels and their corresponding prevalence or array of prevalences.
19
23
  """
20
24
  if isinstance(y, np.ndarray):
21
25
  y = pd.Series(y)
26
+
27
+ counts = y.value_counts(normalize=True).sort_index()
28
+
29
+ if classes is not None:
30
+ counts = counts.reindex(classes, fill_value=0.0)
31
+
22
32
  if format == "array":
23
- prevalences = y.value_counts(normalize=True).sort_index().values
24
- return prevalences
25
- real_prevs = y.value_counts(normalize=True).to_dict()
26
- real_prevs = dict(sorted(real_prevs.items()))
33
+ return counts.values
34
+
35
+ real_prevs = counts.to_dict()
27
36
  return real_prevs
28
37
 
29
38
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlquantify
3
- Version: 0.1.18
3
+ Version: 0.1.19
4
4
  Summary: Quantification Library
5
5
  Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
6
6
  Maintainer: Luiz Fernando Luth Junior
@@ -23,7 +23,7 @@ mlquantify/mixture/_base.py,sha256=1-yW64FPQXB_d9hH9KjSlDnmFtW9FY7S2hppXAd1DBg,5
23
23
  mlquantify/mixture/_classes.py,sha256=Dx0KWS-RtVVmJwXvPKIVWitsJhgcYRRiypLYrgE66x4,16420
24
24
  mlquantify/mixture/_utils.py,sha256=CKlC081nrkJ8Pil7lrPZvNZC_xfpXV8SsuQq3M_LHgA,4037
25
25
  mlquantify/model_selection/__init__.py,sha256=98I0uf8k6lbWAjazGyGjbOdPOvzU8aMRLqC3I7D3jzk,113
26
- mlquantify/model_selection/_protocol.py,sha256=2k0M_7YwZf7YLoQ8ElR2xMvLySVgtE_EvWieMXTIzTA,12499
26
+ mlquantify/model_selection/_protocol.py,sha256=XhkNUN-XAuGkihm0jwQL665ps2G9bevxme_yrETNQHo,12902
27
27
  mlquantify/model_selection/_search.py,sha256=1UoP3tZ-pdfM25C-gOS89qjGKcDgQEeU7GTbwtsLKHU,10695
28
28
  mlquantify/model_selection/_split.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
29
29
  mlquantify/neighbors/__init__.py,sha256=rIOuSaUhjqEXsUN9HNZ62P53QG0N7lJ3j1pvf8kJzms,93
@@ -43,11 +43,11 @@ mlquantify/utils/_get_scores.py,sha256=VlTvgg_t4D9MzcgsH7YvP_wIL5AZ8XmEtGpbFivdV
43
43
  mlquantify/utils/_load.py,sha256=cMGXIs-8mUB4blAmagyDNNvAaV2hysRgeInQMl5fDHg,303
44
44
  mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM,196
45
45
  mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
46
- mlquantify/utils/_sampling.py,sha256=QQxE2WKLdiCFUfPF6fKgzyrsOUIWYf74w_w8fbYVc2c,8409
46
+ mlquantify/utils/_sampling.py,sha256=3W0vUuvLvoYrt-BZpSM0HM1XJEZr0XYIdkOcUP5hp-8,8350
47
47
  mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
48
48
  mlquantify/utils/_validation.py,sha256=zn4OHfa704YBaPKskhiThUG7wS5fvDoHBpcEgb1i8qM,18078
49
- mlquantify/utils/prevalence.py,sha256=FXLCJViQb2yDbyTXeGZt8WsPPnSZINhorQYZTKXOn14,1772
50
- mlquantify-0.1.18.dist-info/METADATA,sha256=XrQ188Icw5RZEAN8tvHRHTsRm1IKB1iwR_tm6G7uB0w,4701
51
- mlquantify-0.1.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
52
- mlquantify-0.1.18.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
53
- mlquantify-0.1.18.dist-info/RECORD,,
49
+ mlquantify/utils/prevalence.py,sha256=LG-KXJ5Eb4w26WMpu4PoBpxMSHaqrmTQqdRlyqNRJ1o,2020
50
+ mlquantify-0.1.19.dist-info/METADATA,sha256=nQ0BqrdrpxbBTHhFh6p2M9qXqQsehRAdqIB5cpNbr1s,4701
51
+ mlquantify-0.1.19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
52
+ mlquantify-0.1.19.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
53
+ mlquantify-0.1.19.dist-info/RECORD,,