statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,261 @@
1
+ """K-Means clustering."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional, Union
6
+
7
+ import numpy as np
8
+ from scipy import sparse
9
+
10
+ from statgpu._base import BaseEstimator
11
+ from statgpu._config import Device
12
+ from statgpu.unsupervised._utils import check_2d_array, scalar_to_float, scalar_to_int
13
+
14
+
15
class KMeans(BaseEstimator):
    """
    Lloyd K-Means clustering with NumPy, CuPy, or Torch backends.

    Parameters
    ----------
    n_clusters : int, default=8
        Number of clusters.
    init : {'k-means++', 'random'}, default='k-means++'
        Initialization strategy. ``'k-means++'`` uses greedy local trials,
        matching the practical variant used by scikit-learn.
    n_init : 'auto' or int, default='auto'
        Number of initializations. ``'auto'`` means 1 for k-means++ and 10
        for random initialization.
    max_iter : int, default=300
        Maximum Lloyd iterations per initialization.
    tol : float, default=1e-4
        Absolute convergence tolerance: iteration stops once the summed
        squared movement of all centers is less than or equal to ``tol``.
    random_state : int or None, default=None
        Random seed for deterministic initialization.
    device : {'auto', 'cpu', 'cuda', 'torch'}, default='auto'
        Compute device.
    n_jobs : int or None, default=None
        Forwarded unchanged to ``BaseEstimator``; this class does not read it.

    Attributes
    ----------
    cluster_centers_ : backend array of shape (n_clusters, n_features)
        Centers of the run with the lowest inertia.
    labels_ : backend array of shape (n_samples,)
        Index of the closest center for every training sample.
    inertia_ : float
        Sum of squared distances of the training samples to their closest
        center.
    n_iter_ : int
        Number of Lloyd iterations executed by the selected run.
    n_features_in_ : int
        Number of columns seen during ``fit``.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init: str = "k-means++",
        n_init: Union[str, int] = "auto",
        max_iter: int = 300,
        tol: float = 1e-4,
        random_state: Optional[int] = None,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_clusters = n_clusters
        self.init = init
        self.n_init = n_init
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

    def _validate_params(self, n_samples: int):
        """Validate hyper-parameters and return ``(n_clusters, n_init)`` as ints.

        Raises
        ------
        ValueError
            If any hyper-parameter is out of range or of the wrong type.
        NotImplementedError
            If ``init`` is a callable.
        """
        if not isinstance(self.n_clusters, (int, np.integer)) or int(self.n_clusters) < 1:
            raise ValueError("n_clusters must be a positive integer")
        n_clusters = int(self.n_clusters)
        if n_clusters > n_samples:
            raise ValueError("n_clusters must be less than or equal to n_samples")

        # Test the type first: ``array not in (str, str)`` would otherwise fail
        # with an unrelated "truth value of an array is ambiguous" error.
        if not isinstance(self.init, str) or self.init not in ("k-means++", "random"):
            if callable(self.init):
                raise NotImplementedError("callable init is not supported in KMeans v1")
            raise ValueError("init must be one of: 'k-means++', 'random'")

        if isinstance(self.n_init, str) and self.n_init == "auto":
            n_init = 1 if self.init == "k-means++" else 10
        else:
            if not isinstance(self.n_init, (int, np.integer)) or int(self.n_init) < 1:
                raise ValueError("n_init must be 'auto' or a positive integer")
            n_init = int(self.n_init)

        if not isinstance(self.max_iter, (int, np.integer)) or int(self.max_iter) < 1:
            raise ValueError("max_iter must be a positive integer")

        # ``not (tol >= 0)`` also rejects NaN, which would silently disable the
        # ``shift <= tol`` convergence test.
        if not float(self.tol) >= 0.0:
            raise ValueError("tol must be non-negative")
        return n_clusters, n_init

    def _squared_distances(self, backend, X, centers):
        """Return the (n_samples, n_centers) matrix of squared distances."""
        x_norm = backend.sum(X * X, axis=1, keepdims=True)
        return self._squared_distances_with_x_norm(backend, X, x_norm, centers)

    def _squared_distances_with_x_norm(self, backend, X, x_norm, centers):
        """Squared distances using precomputed row norms of ``X``.

        Uses ``||x||^2 + ||c||^2 - 2 x.c``; the result is clamped at zero
        because rounding can make the expansion slightly negative.
        """
        c_norm = backend.sum(centers * centers, axis=1, keepdims=False)
        c_norm = backend.reshape(c_norm, (1, centers.shape[0]))
        distances = x_norm + c_norm - 2.0 * backend.matmul(X, centers.T)
        return backend.maximum(distances, 0.0)

    def _init_centers(self, backend, X, rng, n_clusters, x_norm):
        """Pick ``n_clusters`` initial centers from the rows of ``X``.

        ``'random'`` samples distinct rows uniformly. Otherwise a greedy
        k-means++ is run: each step draws ``2 + int(log(n_clusters))``
        candidates with probability proportional to the current squared
        distance to the nearest chosen center and keeps the candidate that
        minimizes the total potential.
        """
        n_samples = X.shape[0]
        n_features = X.shape[1]
        if self.init == "random":
            indices = rng.choice(n_samples, size=n_clusters, replace=False)
            indices_backend = backend.asarray(indices, dtype=backend.int64)
            return backend.copy(X[indices_backend])

        first = int(rng.integers(0, n_samples))
        centers = [backend.copy(X[first])]
        closest_dist_sq = self._squared_distances_with_x_norm(
            backend, X, x_norm, backend.reshape(centers[0], (1, n_features))
        )[:, 0]
        selected = [first]
        n_local_trials = 2 + int(np.log(n_clusters))

        for _ in range(1, n_clusters):
            probs = backend.to_numpy(closest_dist_sq).astype(np.float64, copy=False)
            total = float(np.sum(probs))
            if total <= 0.0 or not np.isfinite(total):
                # Degenerate weights (all remaining points coincide with a
                # chosen center, or non-finite data): fall back to a uniform
                # draw among rows that have not been selected yet.
                candidates = np.setdiff1d(np.arange(n_samples), np.asarray(selected))
                next_idx = int(rng.choice(candidates)) if candidates.size else first
                centers.append(backend.copy(X[next_idx]))
                new_dist_sq = self._squared_distances_with_x_norm(
                    backend, X, x_norm, backend.reshape(centers[-1], (1, n_features))
                )[:, 0]
                closest_dist_sq = backend.minimum(closest_dist_sq, new_dist_sq)
            else:
                candidate_ids = rng.choice(
                    n_samples, size=n_local_trials, replace=True, p=probs / total
                )
                candidate_ids_backend = backend.asarray(candidate_ids, dtype=backend.int64)
                candidate_centers = X[candidate_ids_backend]
                candidate_dist_sq = self._squared_distances_with_x_norm(
                    backend, X, x_norm, candidate_centers
                )
                # Column j holds the per-sample closest distance if candidate j
                # were added to the current set of centers.
                trial_dist_sq = backend.minimum(
                    backend.expand_dims(closest_dist_sq, axis=1), candidate_dist_sq
                )
                potentials = backend.sum(trial_dist_sq, axis=0)
                best_trial = scalar_to_int(backend.argmin(potentials, axis=0))
                next_idx = int(candidate_ids[best_trial])
                centers.append(backend.copy(candidate_centers[best_trial]))
                closest_dist_sq = trial_dist_sq[:, best_trial]
            selected.append(next_idx)
        return backend.stack(centers, axis=0)

    def _labels_min_distances(self, backend, distances):
        """Return ``(labels, min_dist_sq)`` for a sample-by-center distance matrix."""
        labels = backend.argmin(distances, axis=1)
        gathered = backend.take_along_axis(
            distances, backend.expand_dims(labels, axis=1), axis=1
        )
        return labels, backend.reshape(gathered, (distances.shape[0],))

    def _bincount(self, backend, labels, n_clusters):
        """Count samples per cluster; the result has at least ``n_clusters`` bins."""
        # The same call is used for every backend, so no dispatch on
        # ``backend.name`` is needed.
        return backend.xp.bincount(labels, minlength=n_clusters)

    def _label_sums(self, backend, X, labels, n_clusters):
        """Return the (n_clusters, n_features) per-cluster sums of the rows of ``X``."""
        if backend.name == "numpy":
            if X.shape[0] * n_clusters <= 5_000_000:
                # Small problems: a dense one-hot indicator turns the grouped
                # sum into a single matrix product.
                indicator = np.zeros((X.shape[0], n_clusters), dtype=X.dtype)
                indicator[np.arange(X.shape[0]), labels] = 1.0
                return indicator.T @ X
            # Large problems: avoid the dense indicator and accumulate one
            # feature column at a time.
            return np.stack(
                [
                    np.bincount(labels, weights=X[:, j], minlength=n_clusters)
                    for j in range(X.shape[1])
                ],
                axis=1,
            )
        if backend.name == "torch":
            sums = backend.zeros((n_clusters, X.shape[1]), dtype=X.dtype)
            return sums.index_add(0, labels, X)
        # Any other backend (presumably CuPy): unbuffered scatter-add.
        sums = backend.zeros((n_clusters, X.shape[1]), dtype=X.dtype)
        backend.xp.add.at(sums, labels, X)
        return sums

    def _compute_centers(self, backend, X, labels, min_dist_sq, centers):
        """Return the updated centers for the given assignment.

        ``centers`` is only consulted for the number of clusters. Clusters
        that received no sample are re-seeded with the samples that are
        currently farthest from their assigned center.
        """
        n_clusters = int(centers.shape[0])
        counts = self._bincount(backend, labels, n_clusters)
        sums = self._label_sums(backend, X, labels, n_clusters)
        # Clamp to one so that empty clusters do not divide by zero; their rows
        # are overwritten below.
        safe_counts = backend.maximum(backend.reshape(counts, (n_clusters, 1)), 1)
        new_centers = sums / safe_counts

        empty = backend.to_numpy(counts == 0)
        if np.any(empty):
            empty_idx = np.flatnonzero(empty)
            n_empty = int(empty_idx.shape[0])
            # Assign distinct replacement samples to each empty cluster to avoid
            # duplicated centroids that can remain empty due to argmin tie-breaking.
            replacement_idx = np.argsort(backend.to_numpy(min_dist_sq))[::-1][:n_empty]
            int_dtype = getattr(backend, "int64", getattr(backend.xp, "int64", np.int64))
            empty_backend = backend.asarray(empty_idx, dtype=int_dtype)
            replacement_backend = backend.asarray(replacement_idx, dtype=int_dtype)
            new_centers[empty_backend] = X[replacement_backend]
        return new_centers

    def _run_single(self, backend, X, rng, n_clusters):
        """Run one initialization followed by Lloyd iterations.

        Returns ``(centers, labels, inertia, n_iter)``.
        """
        x_norm = backend.sum(X * X, axis=1, keepdims=True)
        centers = self._init_centers(backend, X, rng, n_clusters, x_norm)
        n_iter = 0
        for n_iter in range(1, int(self.max_iter) + 1):
            distances = self._squared_distances_with_x_norm(backend, X, x_norm, centers)
            labels, min_dist_sq = self._labels_min_distances(backend, distances)
            new_centers = self._compute_centers(backend, X, labels, min_dist_sq, centers)
            center_shift = backend.sum((new_centers - centers) ** 2)
            centers = new_centers
            if scalar_to_float(center_shift) <= float(self.tol):
                break
        # The labels computed inside the loop refer to the previous centers, so
        # assign once more to keep labels and inertia consistent with the
        # centers that are returned.
        distances = self._squared_distances_with_x_norm(backend, X, x_norm, centers)
        labels, min_dist_sq = self._labels_min_distances(backend, distances)
        inertia = scalar_to_float(backend.sum(min_dist_sq))
        return centers, labels, inertia, n_iter

    def fit(self, X, y=None, sample_weight=None):
        """Cluster ``X`` and keep the run with the lowest inertia.

        Parameters
        ----------
        X : dense array-like of shape (n_samples, n_features)
        y : ignored
        sample_weight : None
            Any other value raises ``NotImplementedError``.

        Returns
        -------
        self
        """
        if sparse.issparse(X):
            raise NotImplementedError("sparse input is not supported in KMeans v1")
        if sample_weight is not None:
            raise NotImplementedError("sample_weight is not supported in KMeans v1")
        backend = self._get_backend()
        X_arr = backend.asarray(X, dtype=backend.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        n_clusters, n_init = self._validate_params(n_samples)

        # One generator is shared by all restarts, so every restart sees a
        # different part of the same seeded stream.
        rng = np.random.default_rng(self.random_state)
        best = None
        for _ in range(n_init):
            centers, labels, inertia, n_iter = self._run_single(backend, X_arr, rng, n_clusters)
            if best is None or inertia < best[2]:
                best = (centers, labels, inertia, n_iter)

        self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = best
        self.n_features_in_ = int(n_features)
        self._backend_name = backend.name
        self._fitted = True
        return self

    def _prepare_query_array(self, X):
        """Validate ``X`` against the fitted model.

        Returns ``(backend, X_arr)`` where ``X_arr`` is ``X`` converted to a
        float64 backend array.

        Raises
        ------
        NotImplementedError
            If ``X`` is a SciPy sparse matrix.
        ValueError
            If the number of columns differs from ``n_features_in_``.
        """
        self._check_is_fitted()
        if sparse.issparse(X):
            raise NotImplementedError("sparse input is not supported in KMeans v1")
        backend = self._get_backend()
        X_arr = backend.asarray(X, dtype=backend.float64)
        check_2d_array(X_arr)
        if X_arr.shape[1] != self.n_features_in_:
            raise ValueError(f"X has {X_arr.shape[1]} features, expected {self.n_features_in_}")
        return backend, X_arr

    def transform(self, X):
        """Return the (n_samples, n_clusters) distances to the fitted centers."""
        backend, X_arr = self._prepare_query_array(X)
        return backend.sqrt(self._squared_distances(backend, X_arr, self.cluster_centers_))

    def predict(self, X):
        """Return the index of the closest fitted center for every row of ``X``."""
        backend, X_arr = self._prepare_query_array(X)
        distances = self._squared_distances(backend, X_arr, self.cluster_centers_)
        return backend.argmin(distances, axis=1)

    def fit_predict(self, X, y=None):
        """Fit the model and return the training labels."""
        return self.fit(X, y=y).labels_

    def score(self, X, y=None):
        """Return the negative inertia of ``X`` with respect to the fitted centers.

        The input goes through the same sparse, dimensionality and
        feature-count checks as ``transform`` and ``predict``.
        """
        backend, X_arr = self._prepare_query_array(X)
        distances = self._squared_distances(backend, X_arr, self.cluster_centers_)
        min_dist_sq = backend.min(distances, axis=1)
        return -scalar_to_float(backend.sum(min_dist_sq))

    def get_params(self, deep=True):
        """Return the base estimator parameters plus the K-Means hyper-parameters."""
        params = super().get_params(deep=deep)
        params.update(
            {
                "n_clusters": self.n_clusters,
                "init": self.init,
                "n_init": self.n_init,
                "max_iter": self.max_iter,
                "tol": self.tol,
                "random_state": self.random_state,
            }
        )
        return params
@@ -0,0 +1,299 @@
1
+ """Mini-batch K-Means clustering."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional, Union
6
+
7
+ import numpy as np
8
+ from scipy import sparse
9
+
10
+ from statgpu._base import BaseEstimator
11
+ from statgpu._config import Device
12
+ from statgpu.unsupervised._kmeans import KMeans
13
+ from statgpu.unsupervised._utils import check_2d_array, scalar_to_float
14
+
15
+
16
class MiniBatchKMeans(BaseEstimator):
    """
    Mini-batch Lloyd K-Means with NumPy, CuPy, or Torch backends.

    Parameters
    ----------
    n_clusters : int, default=8
        Number of clusters.
    init : {'k-means++', 'random'} or array-like, default='k-means++'
        Initialization strategy, or explicit initial centers of shape
        ``(n_clusters, n_features)``. Callables are not supported.
    n_init : 'auto' or int, default='auto'
        Number of independent runs. ``'auto'`` means 1 for k-means++ or an
        explicit array and 3 for random initialization.
    batch_size : int, default=1024
        Number of samples per mini-batch; capped at ``n_samples``.
    max_iter : int, default=100
        Maximum number of passes over the shuffled data.
    max_no_improvement : int or None, default=10
        Stop once this many consecutive mini-batches fail to improve on the
        best per-sample batch inertia seen so far. ``None`` disables the
        check.
    tol : float, default=0.0
        A run stops when the summed squared movement of all centers over one
        pass is less than or equal to ``tol``.
    random_state : int or None, default=None
        Random seed for initialization and shuffling.
    device : {'auto', 'cpu', 'cuda', 'torch'}, default='auto'
        Compute device.
    n_jobs : int or None, default=None
        Forwarded to ``BaseEstimator`` and to the internal ``KMeans`` helper.

    Attributes
    ----------
    cluster_centers_ : backend array of shape (n_clusters, n_features)
    counts_ : backend array of shape (n_clusters,)
        Per-cluster sample counts.
    labels_ : backend array
        Labels of the data passed to the last ``fit`` / ``partial_fit``.
    inertia_ : float
    n_iter_ : int
    n_steps_ : int
        Number of mini-batch updates performed.
    n_features_in_ : int
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init="k-means++",
        n_init: Union[str, int] = "auto",
        batch_size: int = 1024,
        max_iter: int = 100,
        max_no_improvement: int = 10,
        tol: float = 0.0,
        random_state: Optional[int] = None,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_clusters = n_clusters
        self.init = init
        self.n_init = n_init
        self.batch_size = batch_size
        self.max_iter = max_iter
        self.max_no_improvement = max_no_improvement
        self.tol = tol
        self.random_state = random_state

    def _validate_params(self, n_samples: int, n_features: int):
        """Validate hyper-parameters and return ``(n_clusters, n_init)`` as ints.

        Raises
        ------
        ValueError
            If any hyper-parameter is out of range, of the wrong type, or an
            explicit ``init`` has the wrong shape.
        NotImplementedError
            If ``init`` is a callable.
        """
        if not isinstance(self.n_clusters, (int, np.integer)) or int(self.n_clusters) < 1:
            raise ValueError("n_clusters must be a positive integer")
        n_clusters = int(self.n_clusters)
        if n_clusters > n_samples:
            raise ValueError("n_clusters must be less than or equal to n_samples")

        if isinstance(self.init, str):
            if self.init not in ("k-means++", "random"):
                raise ValueError("init must be one of: 'k-means++', 'random', or an array of centers")
        elif callable(self.init):
            raise NotImplementedError("callable init is not supported in MiniBatchKMeans v1")
        else:
            init_shape = getattr(self.init, "shape", None)
            if init_shape is None:
                # Nested lists/tuples carry no ``shape`` attribute but are
                # later converted with ``backend.asarray`` in ``_init_centers``.
                try:
                    init_shape = np.shape(self.init)
                except (TypeError, ValueError):
                    init_shape = None
            if init_shape is None or tuple(init_shape) != (n_clusters, n_features):
                raise ValueError(
                    f"init array must have shape ({n_clusters}, {n_features}), got {init_shape}"
                )

        if isinstance(self.n_init, str) and self.n_init == "auto":
            n_init = 1 if not isinstance(self.init, str) or self.init == "k-means++" else 3
        else:
            if not isinstance(self.n_init, (int, np.integer)) or int(self.n_init) < 1:
                raise ValueError("n_init must be 'auto' or a positive integer")
            n_init = int(self.n_init)

        if not isinstance(self.batch_size, (int, np.integer)) or int(self.batch_size) < 1:
            raise ValueError("batch_size must be a positive integer")
        if not isinstance(self.max_iter, (int, np.integer)) or int(self.max_iter) < 1:
            raise ValueError("max_iter must be a positive integer")
        if self.max_no_improvement is not None:
            if (
                not isinstance(self.max_no_improvement, (int, np.integer))
                or int(self.max_no_improvement) < 0
            ):
                raise ValueError("max_no_improvement must be None or a non-negative integer")
        # ``not (tol >= 0)`` also rejects NaN, which would silently disable the
        # ``shift <= tol`` convergence test.
        if not float(self.tol) >= 0.0:
            raise ValueError("tol must be non-negative")
        return n_clusters, n_init

    def _helper(self):
        """Build a ``KMeans`` instance whose private kernels are reused here.

        Only its distance, labelling, counting and initialization helpers are
        called; it is never fitted. An explicit array ``init`` is replaced by
        ``'random'`` because the helper only accepts string strategies.
        """
        init = self.init if isinstance(self.init, str) else "random"
        return KMeans(
            n_clusters=int(self.n_clusters),
            init=init,
            n_init=1,
            max_iter=1,
            tol=self.tol,
            random_state=self.random_state,
            device=self.device,
            n_jobs=self.n_jobs,
        )

    def _init_centers(self, backend, X, rng, n_clusters, helper):
        """Return initial centers: the user array, or the helper's seeding."""
        if not isinstance(self.init, str):
            return backend.asarray(self.init, dtype=backend.float64)
        x_norm = backend.sum(X * X, axis=1, keepdims=True)
        return helper._init_centers(backend, X, rng, n_clusters, x_norm)

    def _single_batch_step(self, backend, X_batch, centers, counts, helper, return_inertia: bool = True):
        """Apply one mini-batch update.

        Each center that received samples moves toward its batch mean with
        step ``batch_count / (previous_count + batch_count)``; centers without
        samples in the batch keep both their position and their count.

        Returns ``(centers, counts, labels, inertia)``. ``inertia`` is the
        batch's summed squared distance to the centers *before* the update,
        or ``None`` when ``return_inertia`` is false.
        """
        n_centers = centers.shape[0]
        distances = helper._squared_distances(backend, X_batch, centers)
        labels, min_dist_sq = helper._labels_min_distances(backend, distances)
        batch_counts = helper._bincount(backend, labels, n_centers)
        batch_sums = helper._label_sums(backend, X_batch, labels, n_centers)
        non_empty = batch_counts > 0
        new_counts = counts + batch_counts
        # Clamp to one so centers with no samples do not divide by zero; those
        # rows are masked out by ``non_empty`` below.
        safe_batch_counts = backend.maximum(backend.reshape(batch_counts, (n_centers, 1)), 1)
        batch_means = batch_sums / safe_batch_counts
        eta = backend.reshape(batch_counts / backend.maximum(new_counts, 1), (n_centers, 1))
        updated = centers + eta * (batch_means - centers)
        centers = backend.where(backend.reshape(non_empty, (n_centers, 1)), updated, centers)
        counts = backend.where(non_empty, new_counts, counts)
        inertia = scalar_to_float(backend.sum(min_dist_sq)) if return_inertia else None
        return centers, counts, labels, inertia

    def _final_labels_inertia(self, backend, X, centers, helper):
        """Return ``(labels, inertia)`` of ``X`` for the given centers."""
        distances = helper._squared_distances(backend, X, centers)
        labels, min_dist_sq = helper._labels_min_distances(backend, distances)
        return labels, scalar_to_float(backend.sum(min_dist_sq))

    def _lloyd_polish(self, backend, X, centers, helper, n_steps: int = 2):
        """Run a small exact Lloyd polish after mini-batch updates.

        Returns ``(centers, counts, labels, inertia)`` where labels, inertia
        and counts are all evaluated against the polished centers.
        """
        for _ in range(int(n_steps)):
            distances = helper._squared_distances(backend, X, centers)
            labels, min_dist_sq = helper._labels_min_distances(backend, distances)
            centers = helper._compute_centers(backend, X, labels, min_dist_sq, centers)
        # Labels, inertia and counts are only needed for the final centers, so
        # they are computed once here rather than on every polish step.
        labels, inertia = self._final_labels_inertia(backend, X, centers, helper)
        counts = helper._bincount(backend, labels, centers.shape[0])
        return centers, counts, labels, inertia

    def _run_single(self, backend, X, rng, n_clusters, helper):
        """Run one initialization, the mini-batch passes and the final polish.

        Returns ``(centers, counts, labels, inertia, n_iter, n_steps)``.
        """
        centers = self._init_centers(backend, X, rng, n_clusters, helper)
        counts = backend.zeros((n_clusters,), dtype=backend.float64)
        n_samples = X.shape[0]
        batch_size = min(int(self.batch_size), n_samples)
        best_batch_score = None
        no_improvement = 0
        n_steps = 0
        n_iter = 0
        last_centers = backend.copy(centers)
        track_improvement = self.max_no_improvement is not None
        patience = int(self.max_no_improvement) if track_improvement else 0

        for n_iter in range(1, int(self.max_iter) + 1):
            order = rng.permutation(n_samples)
            order_backend = backend.asarray(order, dtype=backend.int64)
            for start in range(0, n_samples, batch_size):
                batch_idx = order_backend[start : start + batch_size]
                X_batch = X[batch_idx]
                centers, counts, _, batch_inertia = self._single_batch_step(
                    backend, X_batch, centers, counts, helper, return_inertia=track_improvement
                )
                n_steps += 1
                if track_improvement:
                    # Compare per-sample inertia: the last batch of a pass can
                    # be shorter than ``batch_size``, and its raw sum would
                    # otherwise always look like an improvement.
                    batch_score = batch_inertia / X_batch.shape[0]
                    if best_batch_score is None or batch_score < best_batch_score:
                        best_batch_score = batch_score
                        no_improvement = 0
                    else:
                        no_improvement += 1
                    if no_improvement >= patience:
                        break

            center_shift = scalar_to_float(backend.sum((centers - last_centers) ** 2))
            last_centers = backend.copy(centers)
            if center_shift <= float(self.tol):
                break
            # Propagate the early stop triggered inside the batch loop.
            if track_improvement and no_improvement >= patience:
                break

        centers, counts, labels, inertia = self._lloyd_polish(backend, X, centers, helper)
        return centers, counts, labels, inertia, n_iter, n_steps

    def fit(self, X, y=None, sample_weight=None):
        """Cluster ``X`` with mini-batch updates and keep the lowest-inertia run.

        Parameters
        ----------
        X : dense array-like of shape (n_samples, n_features)
        y : ignored
        sample_weight : None
            Any other value raises ``NotImplementedError``.

        Returns
        -------
        self
        """
        if sparse.issparse(X):
            raise NotImplementedError("sparse input is not supported in MiniBatchKMeans v1")
        if sample_weight is not None:
            raise NotImplementedError("sample_weight is not supported in MiniBatchKMeans v1")
        backend = self._get_backend()
        X_arr = backend.asarray(X, dtype=backend.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        n_clusters, n_init = self._validate_params(n_samples, n_features)
        helper = self._helper()

        rng = np.random.default_rng(self.random_state)
        best = None
        for _ in range(n_init):
            # Every run gets its own generator seeded from the master stream.
            run_rng = np.random.default_rng(rng.integers(0, np.iinfo(np.int32).max))
            result = self._run_single(backend, X_arr, run_rng, n_clusters, helper)
            if best is None or result[3] < best[3]:
                best = result

        (
            self.cluster_centers_,
            self.counts_,
            self.labels_,
            self.inertia_,
            self.n_iter_,
            self.n_steps_,
        ) = best
        self.n_features_in_ = int(n_features)
        self._backend_name = backend.name
        self._fitted = True
        return self

    def partial_fit(self, X, y=None, sample_weight=None):
        """Update the centers with a single mini-batch ``X``.

        On the first call the centers are initialized from ``X`` (or from the
        explicit ``init`` array). ``labels_`` and ``inertia_`` afterwards
        describe this batch only, measured against the centers before the
        update.

        Returns
        -------
        self
        """
        if sparse.issparse(X):
            raise NotImplementedError("sparse input is not supported in MiniBatchKMeans v1")
        if sample_weight is not None:
            raise NotImplementedError("sample_weight is not supported in MiniBatchKMeans v1")
        backend = self._get_backend()
        X_arr = backend.asarray(X, dtype=backend.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        # Only a string-based first initialization has to draw ``n_clusters``
        # rows from this batch; otherwise a batch smaller than ``n_clusters``
        # is acceptable, so the sample-count check is neutralized.
        needs_batch_based_init = (not self._fitted) and isinstance(self.init, str)
        if needs_batch_based_init:
            validation_samples = n_samples
        else:
            validation_samples = max(n_samples, int(self.n_clusters))
        n_clusters, _ = self._validate_params(validation_samples, n_features)
        helper = self._helper()

        if not self._fitted:
            rng = np.random.default_rng(self.random_state)
            self.cluster_centers_ = self._init_centers(backend, X_arr, rng, n_clusters, helper)
            self.counts_ = backend.zeros((n_clusters,), dtype=backend.float64)
            self.n_steps_ = 0
            self.n_iter_ = 0
            self.n_features_in_ = int(n_features)
            self._backend_name = backend.name
            self._fitted = True
        elif X_arr.shape[1] != self.n_features_in_:
            raise ValueError(f"X has {X_arr.shape[1]} features, expected {self.n_features_in_}")

        centers, counts, labels, inertia = self._single_batch_step(
            backend, X_arr, self.cluster_centers_, self.counts_, helper
        )
        self.cluster_centers_ = centers
        self.counts_ = counts
        self.labels_ = labels
        self.inertia_ = inertia
        self.n_steps_ = int(self.n_steps_) + 1
        self.n_iter_ = int(self.n_iter_) + 1
        return self

    def _prepare_query_array(self, X):
        """Validate ``X`` against the fitted model.

        Returns ``(backend, X_arr)`` where ``X_arr`` is ``X`` converted to a
        float64 backend array.

        Raises
        ------
        NotImplementedError
            If ``X`` is a SciPy sparse matrix.
        ValueError
            If the number of columns differs from ``n_features_in_``.
        """
        self._check_is_fitted()
        if sparse.issparse(X):
            raise NotImplementedError("sparse input is not supported in MiniBatchKMeans v1")
        backend = self._get_backend()
        X_arr = backend.asarray(X, dtype=backend.float64)
        check_2d_array(X_arr)
        if X_arr.shape[1] != self.n_features_in_:
            raise ValueError(f"X has {X_arr.shape[1]} features, expected {self.n_features_in_}")
        return backend, X_arr

    def transform(self, X):
        """Return the (n_samples, n_clusters) distances to the fitted centers."""
        backend, X_arr = self._prepare_query_array(X)
        helper = self._helper()
        return backend.sqrt(helper._squared_distances(backend, X_arr, self.cluster_centers_))

    def predict(self, X):
        """Return the index of the closest fitted center for every row of ``X``."""
        backend, X_arr = self._prepare_query_array(X)
        helper = self._helper()
        distances = helper._squared_distances(backend, X_arr, self.cluster_centers_)
        return backend.argmin(distances, axis=1)

    def fit_predict(self, X, y=None):
        """Fit the model and return the training labels."""
        return self.fit(X, y=y).labels_

    def score(self, X, y=None):
        """Return the negative inertia of ``X`` with respect to the fitted centers."""
        backend, X_arr = self._prepare_query_array(X)
        helper = self._helper()
        distances = helper._squared_distances(backend, X_arr, self.cluster_centers_)
        min_dist_sq = backend.min(distances, axis=1)
        return -scalar_to_float(backend.sum(min_dist_sq))

    def get_params(self, deep=True):
        """Return the base estimator parameters plus the mini-batch hyper-parameters."""
        params = super().get_params(deep=deep)
        params.update(
            {
                "n_clusters": self.n_clusters,
                "init": self.init,
                "n_init": self.n_init,
                "batch_size": self.batch_size,
                "max_iter": self.max_iter,
                "max_no_improvement": self.max_no_improvement,
                "tol": self.tol,
                "random_state": self.random_state,
            }
        )
        return params