statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168):
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,359 @@
1
+ """
2
+ Triton JIT kernel for Cox PH Efron backward gradient/Hessian.
3
+
4
+ Mirrors the algorithm in `_cox_efron_cuda.py` (CuPy RawKernel serial version).
5
+
6
+ Design:
7
+ - Single Triton program (grid=(1,)) executes the entire backward scan.
8
+ - P (feature dim) is constexpr, enabling loop unrolling for small p.
9
+ - Local scalar accumulators where possible; workspace tensor for p*p matrices.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ from typing import Any, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+
19
+
20
+ def _import_triton():
21
+ """Deferred Triton import."""
22
+ try:
23
+ import triton
24
+ import triton.language as tl
25
+ return triton, tl
26
+ except ImportError:
27
+ return None, None
28
+
29
+
30
# Probe for Triton once at import time; both names are None when the import fails.
_triton, _tl = _import_triton()
# Capability flags. Each starts False and is only set True after the matching
# kernel definition / import has succeeded.
HAS_TRITON_EFRON: bool = False
HAS_TRITON_BRESLOW: bool = False

if _triton is not None and _tl is not None:
    try:
        import triton
        import triton.language as tl

        @triton.jit
        def _efron_backward_scan_serial(
            # Input tensors
            X_ptr,  # [n, p] float64, read as row-major with row stride p
            e_eta_ptr,  # [n] float64
            enter_ptr_ptr,  # [nuft+1] int32
            enter_ind_ptr,  # [n_enter_total] int32
            exit_ptr_ptr,  # [nuft+1] int32
            exit_ind_ptr,  # [n_exit_total] int32
            fail_ptr_ptr,  # [nuft+1] int32
            fail_ind_ptr,  # [n_fail_total] int32
            # Workspace (caller-allocated, zeroed)
            ws_ptr,  # [workspace_size] float64
            # Output (caller-allocated, zeroed)
            grad_ptr,  # [p] float64
            hess_ptr,  # float64, written at j * P + k (row stride is P, not p)
            # Parameters
            n,
            p,
            nuft,
            # Compile-time constants
            P: tl.constexpr,
        ):
            """Single-program serial Efron backward scan kernel.

            Walks the ``nuft`` time slots from last to first. For each slot it
            (1) adds the rows listed in the "enter" CSR segment to the running
            sums xp0 / xp1 / xp2, (2) if the slot has failures, accumulates the
            failure-only sums and applies the Efron correction to ``grad_ptr``
            and ``hess_ptr``, and (3) subtracts the rows listed in the "exit"
            CSR segment from the running sums.

            The kernel never reads a program id, so every launched program
            would do the same full scan. ``n`` is accepted but not referenced
            in the body. Loops over features run to the compile-time width
            ``P`` and are masked with ``j < p`` / ``k < p``.
            """

            # Workspace layout (all offsets relative to ws_ptr):
            #   WS_XP0  : 1 value      running sum of e_eta over the risk set
            #   WS_XP1  : P values     running sum of e_eta * X[idx, j]
            #   WS_XP2  : P*P values   running sum of e_eta * X[idx, j] * X[idx, k]
            #   WS_XP1F : P values     per-slot failure-only version of xp1
            #   WS_XP2F : P*P values   per-slot failure-only version of xp2
            # NOTE(review): WS_HESS, WS_SCRATCH and WS_SIZE are assigned but not
            # referenced below. WS_XP1F (1 + 2*P*P) also starts inside the span
            # that a P*P-sized WS_HESS region would occupy; this looks harmless
            # only because WS_HESS is unused - confirm before using it.
            WS_XP0 = 0
            WS_XP1 = 1
            WS_XP2 = 1 + P
            WS_HESS = 1 + P + P * P
            WS_XP1F = 1 + 2 * P * P
            WS_XP2F = 1 + 2 * P * P + P
            WS_SCRATCH = 1 + 3 * P * P + P
            WS_SIZE = 1 + 3 * P * P + P + 1

            # ws_ptr is already zeroed by caller.

            # ---- Backward scan ----
            for ii in range(nuft - 1, -1, -1):
                # ---- Enter phase ----
                # CSR segment [e0, e1) of enter_ind lists the rows for slot ii.
                e0 = tl.load(enter_ptr_ptr + ii)
                e1 = tl.load(enter_ptr_ptr + ii + 1)
                nt = e1 - e0

                if nt > 0:
                    for t in range(0, nt, 1):
                        idx = tl.load(enter_ind_ptr + e0 + t)
                        # NOTE(review): the index arrays are annotated int32
                        # above; idx * p could overflow for very large n * p -
                        # confirm the intended size limits.
                        row_off = idx * p
                        elx = tl.load(e_eta_ptr + idx)

                        # xp0 += elx
                        old = tl.load(ws_ptr + WS_XP0)
                        tl.store(ws_ptr + WS_XP0, old + elx)

                        # xp1[j] += elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                xval = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1 + j)
                                tl.store(ws_ptr + WS_XP1 + j, old + elx * xval)

                        # xp2[j*P+k] += elx * X[idx,j] * X[idx,k]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                        tl.store(ws_ptr + WS_XP2 + j * P + k, old + elx * vj * vk)

                # ---- Fail phase ----
                # m is the number of failures (tied events) in slot ii.
                f0 = tl.load(fail_ptr_ptr + ii)
                f1 = tl.load(fail_ptr_ptr + ii + 1)
                m = f1 - f0

                if m > 0:
                    # Zero xp1f and xp2f in workspace
                    for j in range(0, P, 1):
                        if j < p:
                            tl.store(ws_ptr + WS_XP1F + j, 0.0)
                    for j in range(0, P, 1):
                        if j < p:
                            for k in range(0, P, 1):
                                if k < p:
                                    tl.store(ws_ptr + WS_XP2F + j * P + k, 0.0)

                    # Accumulate fail sums into xp1f, xp2f, xp0f
                    xp0f_acc = 0.0
                    for t in range(0, m, 1):
                        idx = tl.load(fail_ind_ptr + f0 + t)
                        row_off = idx * p
                        elx = tl.load(e_eta_ptr + idx)
                        xp0f_acc = xp0f_acc + elx

                        # grad[j] += X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                old = tl.load(grad_ptr + j)
                                tl.store(grad_ptr + j, old + vj)

                        # xp1f[j] += elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1F + j)
                                tl.store(ws_ptr + WS_XP1F + j, old + elx * vj)

                        # xp2f[j*P+k] += elx * X[idx,j] * X[idx,k]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2F + j * P + k)
                                        tl.store(ws_ptr + WS_XP2F + j * P + k, old + elx * vj * vk)

                    # Efron correction (serial)
                    # For kk = 0..m-1 with Jk = kk / m the denominator is
                    # c0 = xp0 - Jk * xp0f. The five sums below are the scalar
                    # weights needed to apply all m terms at once.
                    xp0v = tl.load(ws_ptr + WS_XP0)
                    sum_inv_c0 = 0.0
                    sum_J_c0 = 0.0
                    sum_aa = 0.0
                    sum_bb = 0.0
                    sum_ab = 0.0
                    for kk in range(0, m, 1):
                        Jk = (kk * 1.0) / (m * 1.0)
                        c0 = xp0v - Jk * xp0f_acc
                        # Clamp so that 1.0 / c0 below stays finite.
                        if c0 < 1e-300:
                            c0 = 1e-300
                        ak = 1.0 / c0
                        bk = Jk * ak
                        sum_inv_c0 = sum_inv_c0 + ak
                        sum_J_c0 = sum_J_c0 + Jk / c0
                        sum_aa = sum_aa + ak * ak
                        sum_bb = sum_bb + bk * bk
                        sum_ab = sum_ab + ak * bk

                    # Apply to grad
                    # grad[j] -= sum_k (xp1[j] - Jk * xp1f[j]) / c0_k
                    for j in range(0, P, 1):
                        if j < p:
                            xp1j = tl.load(ws_ptr + WS_XP1 + j)
                            xp1fj = tl.load(ws_ptr + WS_XP1F + j)
                            old = tl.load(grad_ptr + j)
                            tl.store(grad_ptr + j, old - (xp1j * sum_inv_c0 - xp1fj * sum_J_c0))

                    # Apply to hess
                    # hess[j,k] += sum_k (xp2 - Jk*xp2f)/c0_k
                    #            - sum_k (xp1 - Jk*xp1f)[j] (xp1 - Jk*xp1f)[k] / c0_k^2
                    for j in range(0, P, 1):
                        if j < p:
                            for k in range(0, P, 1):
                                if k < p:
                                    xp2jk = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                    xp2fjk = tl.load(ws_ptr + WS_XP2F + j * P + k)
                                    hess_val = xp2jk * sum_inv_c0 - xp2fjk * sum_J_c0

                                    xp1j_v = tl.load(ws_ptr + WS_XP1 + j)
                                    xp1k_v = tl.load(ws_ptr + WS_XP1 + k)
                                    xp1fj_v = tl.load(ws_ptr + WS_XP1F + j)
                                    xp1fk_v = tl.load(ws_ptr + WS_XP1F + k)
                                    o11 = xp1j_v * xp1k_v
                                    off_v = xp1fj_v * xp1fk_v
                                    cross_v = xp1j_v * xp1fk_v + xp1fj_v * xp1k_v
                                    hsub = sum_aa * o11 + sum_bb * off_v - sum_ab * cross_v
                                    hess_val = hess_val - hsub

                                    # Output uses the padded row stride P.
                                    idx2 = j * P + k
                                    old = tl.load(hess_ptr + idx2)
                                    tl.store(hess_ptr + idx2, hess_val + old)

                # ---- Exit phase ----
                # CSR segment [x0, x1) of exit_ind lists the rows removed from
                # the running sums after slot ii has been processed.
                x0 = tl.load(exit_ptr_ptr + ii)
                x1 = tl.load(exit_ptr_ptr + ii + 1)
                nx = x1 - x0

                if nx > 0:
                    for t in range(0, nx, 1):
                        idx = tl.load(exit_ind_ptr + x0 + t)
                        row_off = idx * p
                        elx = tl.load(e_eta_ptr + idx)

                        # xp0 -= elx
                        old = tl.load(ws_ptr + WS_XP0)
                        tl.store(ws_ptr + WS_XP0, old - elx)

                        # xp1[j] -= elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                xval = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1 + j)
                                tl.store(ws_ptr + WS_XP1 + j, old - elx * xval)

                        # xp2[j*P+k] -= elx * X[idx,j] * X[idx,k]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                        tl.store(ws_ptr + WS_XP2 + j * P + k, old - elx * vj * vk)

        # Reached only if the decorated definition above did not raise.
        HAS_TRITON_EFRON = True

    except Exception:
        # Any failure while defining the kernel disables the Efron Triton path.
        HAS_TRITON_EFRON = False
        _triton = None
        _tl = None
247
+
248
+ # =====================================================================
249
+ # Breslow Hessian — PyTorch GPU path (cuBLAS matmul + vectorized ops)
250
+ # =====================================================================
251
+ # Originally attempted a Triton serial-scan kernel, but Triton 2.0 has a
252
+ # compiler bug producing non-deterministic wrong code for kernels with
253
+ # runtime-bounded loops (while/for with >= 3 iterations). The PyTorch
254
+ # approach is only marginally slower since each op is cuBLAS-optimized.
255
+ try:
256
+ from statgpu.survival._cox_breslow_triton_kernel import (
257
+ compute_breslow_grad_hess_triton,
258
+ _find_p_ce as _find_p_ce_breslow,
259
+ )
260
+ HAS_TRITON_BRESLOW = True
261
+ except Exception:
262
+ compute_breslow_grad_hess_triton = None
263
+ HAS_TRITON_BRESLOW = False
264
+
265
+
266
def _triton_available() -> bool:
    """Report whether the Efron Triton kernel was defined at import time."""
    efron_ready = HAS_TRITON_EFRON
    return efron_ready
268
+
269
+
270
+ _SUPPORTED_P: Tuple[int, ...] = (8, 16, 32, 64, 128)
271
+
272
+
273
+ def _find_p_ce(p: int) -> Optional[int]:
274
+ for sp in _SUPPORTED_P:
275
+ if sp >= p:
276
+ return sp
277
+ return None
278
+
279
+
280
def compute_efron_grad_hess_triton(
    X: Any,
    beta: Any,
    efron_pre: Any,
) -> Optional[Tuple[Any, Any]]:
    """Compute the Efron gradient and Hessian with the serial Triton kernel.

    Parameters
    ----------
    X
        2-D torch tensor of shape ``(n, p)``. It is made contiguous before the
        launch because the kernel addresses it as ``idx * p + j``.
    beta
        Coefficient tensor; only used through ``X @ beta``.
    efron_pre
        Sequence of 5 or 6 items. Items 1-4 are used as
        ``(uft_ix, risk_enter, risk_exit, nuft)``; the others are ignored.

    Returns
    -------
    tuple or None
        ``(grad, hess)`` where ``grad`` has shape ``(p,)`` and ``hess`` is the
        negated kernel output sliced to ``(p, p)``. Returns ``None`` when the
        Triton kernel is unavailable, when ``p`` is larger than every supported
        padded width, or when the kernel launch raises; callers are expected
        to fall back to another implementation in that case.
    """
    if not HAS_TRITON_EFRON:
        return None

    import torch
    from statgpu.survival._cox_efron_cuda import efron_indices_to_csr

    if len(efron_pre) == 6:
        _, uft_ix, risk_enter, risk_exit, nuft, _ = efron_pre
    else:
        _, uft_ix, risk_enter, risk_exit, nuft = efron_pre
    # Normalise to a plain Python int so the kernel launcher always receives
    # a scalar type it can specialise on (NumPy / 0-d tensor scalars included).
    nuft = int(nuft)

    p = int(X.shape[1])
    p_ce = _find_p_ce(p)
    if p_ce is None:
        return None

    if nuft == 0:
        return (
            torch.zeros(p, dtype=torch.float64, device=X.device),
            torch.zeros((p, p), dtype=torch.float64, device=X.device),
        )

    n = int(X.shape[0])
    device = X.device

    # The kernel reads X[idx, j] at offset idx * p + j, i.e. it assumes a
    # row-major layout with row stride p. A strided view (transpose, column
    # slice, ...) would be read incorrectly, so force a contiguous layout.
    # This is a no-op when X is already contiguous.
    X = X.contiguous()

    # Build linear predictor; subtracting the max keeps exp() from overflowing.
    linpred = X @ beta
    linpred = linpred - torch.max(linpred)
    e_eta = torch.exp(linpred)

    # Build CSR
    enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind = efron_indices_to_csr(
        uft_ix, risk_enter, risk_exit, nuft
    )

    enter_ptr_t = torch.as_tensor(enter_ptr, dtype=torch.int32, device=device)
    enter_ind_t = torch.as_tensor(enter_ind, dtype=torch.int32, device=device)
    exit_ptr_t = torch.as_tensor(exit_ptr, dtype=torch.int32, device=device)
    exit_ind_t = torch.as_tensor(exit_ind, dtype=torch.int32, device=device)
    fail_ptr_t = torch.as_tensor(fail_ptr, dtype=torch.int32, device=device)
    fail_ind_t = torch.as_tensor(fail_ind, dtype=torch.int32, device=device)

    # Zeroed workspace. The kernel's largest offset is below
    # 2 + P + 3*P*P, so this size is a safe upper bound.
    ws_size = 1 + 3 * p_ce + 3 * p_ce * p_ce + 1
    ws = torch.zeros(ws_size, dtype=torch.float64, device=device)
    grad_out = torch.zeros(p, dtype=torch.float64, device=device)
    # Allocate hess_out with padded stride (p_ce) to match Triton kernel indexing
    hess_out = torch.zeros(p_ce * p_ce, dtype=torch.float64, device=device)

    try:
        _efron_backward_scan_serial[(1,)](
            X, e_eta,
            enter_ptr_t, enter_ind_t,
            exit_ptr_t, exit_ind_t,
            fail_ptr_t, fail_ind_t,
            ws, grad_out, hess_out,
            n, p, nuft,
            P=p_ce,
        )
        # Wait on the device that owns the tensors, not merely the current one,
        # so asynchronous kernel errors surface inside this try block.
        torch.cuda.synchronize(device)
    except Exception:
        # Deliberate best-effort: any compile/launch failure means "not
        # available" and the caller falls back to a different code path.
        return None

    return grad_out, -hess_out.view(p_ce, p_ce)[:p, :p]
356
+
357
+
358
+ # compute_breslow_grad_hess_triton and _find_p_ce are imported from
359
+ # _cox_breslow_triton_kernel.py above (in the try/except block).
@@ -0,0 +1,29 @@
1
+ """Unsupervised learning estimators."""
2
+
3
+ from ._pca import PCA
4
+ from ._kmeans import KMeans
5
+ from ._dbscan import DBSCAN
6
+ from ._gmm import GaussianMixture
7
+ from ._nmf import NMF
8
+ from ._agglomerative import AgglomerativeClustering
9
+ from ._truncated_svd import TruncatedSVD
10
+ from ._minibatch_kmeans import MiniBatchKMeans
11
+ from ._incremental_pca import IncrementalPCA
12
+ from ._minibatch_nmf import MiniBatchNMF
13
+ from ._umap import UMAP
14
+ from ._tsne import TSNE
15
+
16
+ __all__ = [
17
+ "PCA",
18
+ "KMeans",
19
+ "DBSCAN",
20
+ "GaussianMixture",
21
+ "NMF",
22
+ "AgglomerativeClustering",
23
+ "TruncatedSVD",
24
+ "MiniBatchKMeans",
25
+ "IncrementalPCA",
26
+ "MiniBatchNMF",
27
+ "UMAP",
28
+ "TSNE",
29
+ ]
@@ -0,0 +1,307 @@
1
+ """Agglomerative clustering."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import warnings
7
+ from typing import Optional, Union
8
+
9
+ import numpy as np
10
+ from scipy.cluster.hierarchy import fcluster, linkage
11
+
12
+ from statgpu._base import BaseEstimator
13
+ from statgpu._config import Device
14
+ from statgpu.unsupervised._utils import check_2d_array, reject_sparse, squared_euclidean_distances
15
+
16
+
17
+ DEFAULT_GPU_DISTANCE_LIMIT_BYTES = 1 << 30
18
+
19
+
20
+ def _gpu_distance_limit_bytes() -> int:
21
+ value = os.environ.get("STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES")
22
+ if value is None:
23
+ return DEFAULT_GPU_DISTANCE_LIMIT_BYTES
24
+ try:
25
+ return int(value)
26
+ except ValueError:
27
+ warnings.warn(
28
+ "Invalid STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES value; "
29
+ f"using default {DEFAULT_GPU_DISTANCE_LIMIT_BYTES} bytes.",
30
+ RuntimeWarning,
31
+ stacklevel=2,
32
+ )
33
+ return DEFAULT_GPU_DISTANCE_LIMIT_BYTES
34
+
35
+
36
class AgglomerativeClustering(BaseEstimator):
    """Exact dense agglomerative clustering.

    Parameters
    ----------
    n_clusters
        Number of flat clusters to return; must satisfy
        ``1 <= n_clusters <= n_samples``.
    linkage
        One of ``"single"``, ``"complete"``, ``"average"`` or ``"ward"``.
    metric
        Only ``"euclidean"`` is supported.
    device, n_jobs
        Forwarded unchanged to ``BaseEstimator``.

    Attributes set by ``fit``
    -------------------------
    labels_ : int64 array of shape (n_samples,)
    children_ : int64 array of shape (n_samples - 1, 2); merge ``i`` creates
        cluster id ``n_samples + i``.
    distances_ : float64 array of shape (n_samples - 1,)
    n_features_in_ : int

    Both the CPU and the GPU path derive ``labels_`` by replaying the first
    ``n_samples - n_clusters`` merges, so exactly ``n_clusters`` clusters are
    returned even when several merges happen at the same height.
    """

    # Resolved once, when the class body executes (i.e. at import time).
    _GPU_DISTANCE_LIMIT_BYTES = _gpu_distance_limit_bytes()

    def __init__(
        self,
        n_clusters: int = 2,
        linkage: str = "single",
        metric: str = "euclidean",
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_clusters = n_clusters
        self.linkage = linkage
        self.metric = metric

    def _validate_params(self, n_samples: int):
        """Validate hyper-parameters against the number of samples.

        Raises
        ------
        ValueError
            If ``n_clusters`` is not a positive integer, exceeds
            ``n_samples``, or ``linkage`` is not a supported name.
        NotImplementedError
            If ``metric`` is anything other than ``"euclidean"``.
        """
        if not isinstance(self.n_clusters, (int, np.integer)) or int(self.n_clusters) < 1:
            raise ValueError("n_clusters must be a positive integer")
        if int(self.n_clusters) > n_samples:
            raise ValueError("n_clusters must be less than or equal to n_samples")
        if self.linkage not in ("single", "complete", "average", "ward"):
            raise ValueError("linkage must be one of: 'single', 'complete', 'average', 'ward'")
        if self.metric != "euclidean":
            raise NotImplementedError("AgglomerativeClustering only supports metric='euclidean'")

    def _use_gpu_path(self) -> bool:
        """Return True only when ``device`` is ``Device.CUDA`` or ``Device.TORCH``."""
        return self.device in (Device.CUDA, Device.TORCH)

    def _check_gpu_memory(self, n_samples: int):
        """Raise ``MemoryError`` if an n x n float64 matrix exceeds the limit."""
        required = int(n_samples) * int(n_samples) * 8  # 8 bytes per float64 entry
        if required > self._GPU_DISTANCE_LIMIT_BYTES:
            limit_mb = self._GPU_DISTANCE_LIMIT_BYTES / (1024**2)
            required_mb = required / (1024**2)
            raise MemoryError(
                "AgglomerativeClustering GPU exact path requires a dense "
                f"distance matrix of about {required_mb:.1f} MiB, exceeding "
                f"the configured limit {limit_mb:.1f} MiB. Use device='cpu' "
                "or raise STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES explicitly."
            )

    def _store_agglomerative_result(self, labels, children, distances, n_features, backend_name):
        """Record the fitted state shared by every fit path and return ``self``."""
        self.labels_ = labels
        self.children_ = children
        self.distances_ = distances
        self.n_features_in_ = int(n_features)
        self._backend_name = backend_name
        self._fitted = True
        return self

    @staticmethod
    def _labels_from_children(n_samples: int, n_clusters: int, children: np.ndarray) -> np.ndarray:
        """Cut a merge sequence into exactly ``n_clusters`` flat clusters.

        Replays the first ``n_samples - n_clusters`` rows of ``children``.
        Row ``i`` merges two existing cluster ids into id ``n_samples + i``.
        Labels are assigned in the iteration order of the surviving clusters.
        """
        clusters = {i: [i] for i in range(n_samples)}
        next_id = n_samples
        merges_to_apply = max(0, n_samples - int(n_clusters))
        for left, right in children[:merges_to_apply]:
            members = clusters.pop(int(left)) + clusters.pop(int(right))
            clusters[next_id] = members
            next_id += 1

        labels = np.empty(n_samples, dtype=np.int64)
        for label, members in enumerate(clusters.values()):
            labels[np.asarray(members, dtype=np.int64)] = label
        return labels

    @staticmethod
    def _single_linkage_from_mst(
        n_samples: int,
        edge_parents: np.ndarray,
        edge_children: np.ndarray,
        edge_weights: np.ndarray,
    ):
        """Turn spanning-tree edges into a single-linkage merge sequence.

        Edges are processed in ascending weight order (stable sort) and joined
        with a union-find structure. Returns ``(children, distances)`` with
        ``n_samples - 1`` rows.
        """
        order = np.argsort(edge_weights, kind="mergesort")
        uf_parent = list(range(n_samples))
        # cluster_ids[root] is the dendrogram id currently represented by root.
        cluster_ids = list(range(n_samples))
        children = np.empty((n_samples - 1, 2), dtype=np.int64)
        distances = np.empty(n_samples - 1, dtype=np.float64)

        def find(idx: int) -> int:
            # Path halving: point each visited node at its grandparent.
            while uf_parent[idx] != idx:
                uf_parent[idx] = uf_parent[uf_parent[idx]]
                idx = uf_parent[idx]
            return idx

        merge_step = 0
        for edge_idx in order:
            left_root = find(int(edge_parents[edge_idx]))
            right_root = find(int(edge_children[edge_idx]))
            if left_root == right_root:
                continue
            children[merge_step] = (cluster_ids[left_root], cluster_ids[right_root])
            distances[merge_step] = float(edge_weights[edge_idx])
            uf_parent[right_root] = left_root
            cluster_ids[left_root] = n_samples + merge_step
            merge_step += 1
            if merge_step == n_samples - 1:
                break

        return children, distances

    def _fit_gpu_single(self, backend, X_arr, n_samples: int):
        """Single linkage on the backend via a Prim-style spanning tree.

        Grows the tree from sample 0, one vertex per step, and hands the
        resulting edges to ``_single_linkage_from_mst``.
        """
        D = backend.sqrt(squared_euclidean_distances(backend, X_arr))
        inf = float("inf")
        indices = backend.arange(n_samples, dtype=backend.int64)
        D[indices, indices] = inf  # never pick a self-edge

        selected = backend.zeros(n_samples, dtype=backend.bool)
        selected[0] = True
        # min_dist[v]: cheapest known edge from the tree to v (inf once v is in the tree).
        min_dist = backend.copy(D[0, :])
        min_dist[0] = inf
        nearest_parent = backend.zeros(n_samples, dtype=backend.int64)

        edge_parents = np.empty(n_samples - 1, dtype=np.int64)
        edge_children = np.empty(n_samples - 1, dtype=np.int64)
        edge_weights = np.empty(n_samples - 1, dtype=np.float64)

        for step in range(n_samples - 1):
            child = int(float(backend.argmin(min_dist)))
            edge_children[step] = child
            edge_parents[step] = int(float(nearest_parent[child]))
            edge_weights[step] = float(min_dist[child])

            selected[child] = True
            candidate = D[child, :]
            update_mask = (candidate < min_dist) & (~selected)
            nearest_parent[update_mask] = child
            min_dist = backend.where(update_mask, candidate, min_dist)
            min_dist[child] = inf

        return self._single_linkage_from_mst(n_samples, edge_parents, edge_children, edge_weights)

    def _fit_gpu(self, X):
        """Fit on the configured backend using a dense n x n distance matrix.

        Raises ``MemoryError`` (via ``_check_gpu_memory``) when the matrix
        would exceed the configured byte limit.
        """
        backend = self._get_backend()
        X_arr = self._to_array(X, backend=backend.name)
        X_arr = backend.asarray(X_arr, dtype=backend.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        self._validate_params(n_samples)
        self._check_gpu_memory(n_samples)

        if n_samples == 1:
            return self._store_agglomerative_result(
                np.zeros(1, dtype=np.int64),
                np.empty((0, 2), dtype=np.int64),
                np.empty((0,), dtype=np.float64),
                n_features,
                backend.name,
            )

        if self.linkage == "single" and backend.name in ("cupy", "torch"):
            children, distances = self._fit_gpu_single(backend, X_arr, n_samples)
            labels = self._labels_from_children(n_samples, int(self.n_clusters), children)
            return self._store_agglomerative_result(
                labels, children, distances, n_features, backend.name
            )

        # Ward works on squared distances; the other linkages on plain distances.
        D = squared_euclidean_distances(backend, X_arr)
        if self.linkage != "ward":
            D = backend.sqrt(D)
        inf = float("inf")
        indices = backend.arange(n_samples, dtype=backend.int64)
        D[indices, indices] = inf

        children = np.empty((n_samples - 1, 2), dtype=np.int64)
        distances = np.empty(n_samples - 1, dtype=np.float64)
        cluster_ids = list(range(n_samples))
        cluster_sizes = [1.0] * n_samples
        cluster_sizes_backend = (
            backend.asarray(cluster_sizes, dtype=backend.float64) if self.linkage == "ward" else None
        )

        for step in range(n_samples - 1):
            # Closest live pair; retired rows/columns and the diagonal hold inf.
            flat_idx = int(float(backend.argmin(D)))
            a = flat_idx // n_samples
            b = flat_idx % n_samples
            if b < a:
                a, b = b, a

            merge_value = float(D[a, b])
            children[step] = (cluster_ids[a], cluster_ids[b])
            distances[step] = np.sqrt(max(merge_value, 0.0)) if self.linkage == "ward" else merge_value

            # da / db are rows of D; the in-place branches below write straight into row a.
            da = D[a, :]
            db = D[b, :]
            size_a = cluster_sizes[a]
            size_b = cluster_sizes[b]

            if self.linkage == "single":
                updated = backend.minimum(da, db)
            elif self.linkage == "complete":
                if backend.name in ("cupy", "torch"):
                    backend.xp.maximum(da, db, out=da)
                    updated = da
                else:
                    updated = backend.maximum(da, db)
            elif self.linkage == "average":
                if backend.name in ("cupy", "torch"):
                    da *= size_a
                    da += size_b * db
                    da /= size_a + size_b
                    updated = da
                else:
                    updated = (size_a * da + size_b * db) / (size_a + size_b)
            else:
                # Ward: Lance-Williams recurrence on squared distances.
                total = size_a + size_b + cluster_sizes_backend
                updated = (
                    ((cluster_sizes_backend + size_a) / total) * da
                    + ((cluster_sizes_backend + size_b) / total) * db
                    - (cluster_sizes_backend / total) * merge_value
                )
                updated = backend.maximum(updated, 0.0)

            # Slot a now represents the merged cluster; slot b is retired.
            D[a, :] = updated
            D[:, a] = updated
            cluster_ids[a] = n_samples + step
            cluster_sizes[a] += cluster_sizes[b]
            cluster_sizes[b] = 0.0
            if cluster_sizes_backend is not None:
                cluster_sizes_backend[a] = cluster_sizes[a]
                cluster_sizes_backend[b] = 0.0
            D[b, :] = inf
            D[:, b] = inf
            D[a, a] = inf

        labels = self._labels_from_children(n_samples, int(self.n_clusters), children)
        return self._store_agglomerative_result(
            labels, children, distances, n_features, backend.name
        )

    def fit(self, X, y=None):
        """Fit the hierarchy and compute flat cluster labels.

        ``y`` is accepted for API compatibility and ignored. Sparse input is
        rejected. Returns ``self``.
        """
        reject_sparse(X, "AgglomerativeClustering")
        if self._use_gpu_path():
            return self._fit_gpu(X)

        X_arr = np.asarray(X, dtype=np.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        self._validate_params(n_samples)

        if n_samples == 1:
            children = np.empty((0, 2), dtype=np.int64)
            distances = np.empty((0,), dtype=np.float64)
            labels = np.zeros(1, dtype=np.int64)
        else:
            Z = linkage(X_arr, method=self.linkage, metric="euclidean")
            children = Z[:, :2].astype(np.int64, copy=False)
            distances = Z[:, 2].astype(np.float64, copy=False)
            # Replay the merges instead of using fcluster(criterion="maxclust"):
            # with tied merge heights maxclust can return fewer clusters than
            # requested, whereas this always yields exactly n_clusters and
            # matches the GPU path.
            labels = self._labels_from_children(n_samples, int(self.n_clusters), children)

        return self._store_agglomerative_result(labels, children, distances, n_features, "numpy")

    def fit_predict(self, X, y=None):
        """Fit on ``X`` and return ``labels_``."""
        return self.fit(X, y=y).labels_

    def predict(self, X):
        """Always raises: the hierarchy cannot label unseen samples."""
        raise NotImplementedError("AgglomerativeClustering does not support predict for unseen samples")

    def get_params(self, deep=True):
        """Return the base-class parameters plus ``n_clusters``, ``linkage`` and ``metric``."""
        params = super().get_params(deep=deep)
        params.update(
            {
                "n_clusters": self.n_clusters,
                "linkage": self.linkage,
                "metric": self.metric,
            }
        )
        return params