statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Triton JIT kernel for Cox PH Efron backward gradient/Hessian.
|
|
3
|
+
|
|
4
|
+
Mirrors the algorithm in `_cox_efron_cuda.py` (CuPy RawKernel serial version).
|
|
5
|
+
|
|
6
|
+
Design:
|
|
7
|
+
- Single Triton program (grid=(1,)) executes the entire backward scan.
|
|
8
|
+
- P (the feature dim p rounded up to the next size in _SUPPORTED_P) is constexpr, enabling loop unrolling for small p.
|
|
9
|
+
- Local scalar accumulators where possible; a float64 workspace tensor holds the length-P running vectors and the P*P matrices.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
from typing import Any, List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _import_triton():
    """Import Triton lazily.

    Returns
    -------
    tuple
        ``(triton, triton.language)`` when both import cleanly, otherwise
        ``(None, None)`` (only ``ImportError`` is treated as "unavailable").
    """
    try:
        import triton
        import triton.language as tl
    except ImportError:
        return None, None
    return triton, tl
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Feature flags; each is flipped to True further down only once its GPU
# path has been set up successfully.
HAS_TRITON_EFRON: bool = False
HAS_TRITON_BRESLOW: bool = False

# Probe for Triton once, at import time; both are None when unavailable.
_triton, _tl = _import_triton()
|
|
33
|
+
|
|
34
|
+
if _triton is not None and _tl is not None:
    try:
        import triton
        import triton.language as tl

        @triton.jit
        def _efron_backward_scan_serial(
            # Input tensors
            X_ptr,  # [n, p] float64, contiguous row-major (row stride p)
            e_eta_ptr,  # [n] float64
            enter_ptr_ptr,  # [nuft+1] int32
            enter_ind_ptr,  # [n_enter_total] int32
            exit_ptr_ptr,  # [nuft+1] int32
            exit_ind_ptr,  # [n_exit_total] int32
            fail_ptr_ptr,  # [nuft+1] int32
            fail_ind_ptr,  # [n_fail_total] int32
            # Workspace (caller-allocated, zeroed)
            ws_ptr,  # [workspace_size] float64
            # Output (caller-allocated, zeroed)
            grad_ptr,  # [p] float64
            hess_ptr,  # [P*P] float64, padded row stride P
            # Parameters
            n,
            p,
            nuft,
            # Compile-time constants
            P: tl.constexpr,
        ):
            """Single-program serial Efron backward scan kernel.

            Walks time indices ii = nuft-1 .. 0.  At each index it:
            (1) adds the rows listed in the enter CSR to the running sums,
            (2) if rows fail at ii, adds their X rows to ``grad`` and applies
                the Efron tie-corrected terms to ``grad`` and ``hess``,
            (3) subtracts the rows listed in the exit CSR.
            ``hess`` accumulates ``xp2 * sum(1/c0) - xp2f * sum(J/c0)`` minus
            the outer-product correction; the host negates it afterwards.
            """
            # Workspace layout (float64 slots, offsets relative to ws_ptr),
            # regions laid out back to back so none overlap:
            #   WS_XP0     [1]    running sum of e_eta over rows entered, not yet exited
            #   WS_XP1     [P]    running sum of e_eta * x
            #   WS_XP2     [P*P]  running sum of e_eta * x x^T
            #   WS_HESS    [P*P]  reserved (unused by this kernel)
            #   WS_XP1F    [P]    per-index sum of e_eta * x over failing rows
            #   WS_XP2F    [P*P]  per-index sum of e_eta * x x^T over failing rows
            #   WS_SCRATCH [1]    never written, so it stays 0.0
            # The host allocates 2 + 3*P + 3*P*P slots, which covers WS_SIZE.
            WS_XP0 = 0
            WS_XP1 = 1
            WS_XP2 = 1 + P
            WS_HESS = 1 + P + P * P
            WS_XP1F = 1 + P + 2 * P * P
            WS_XP2F = 1 + 2 * P + 2 * P * P
            WS_SCRATCH = 1 + 2 * P + 3 * P * P
            WS_SIZE = 2 + 2 * P + 3 * P * P

            # Triton types a bare 0.0 literal as fp32, and a loop-carried
            # variable must keep the same type on every iteration.  Seeding
            # the fp64 accumulators from the zeroed scratch slot keeps them
            # fp64 from the start.
            zero_f64 = tl.load(ws_ptr + WS_SCRATCH)

            # ---- Backward scan ----
            for ii in range(nuft - 1, -1, -1):
                # ---- Enter phase: add rows listed for index ii ----
                e0 = tl.load(enter_ptr_ptr + ii)
                e1 = tl.load(enter_ptr_ptr + ii + 1)
                nt = e1 - e0

                if nt > 0:
                    for t in range(0, nt, 1):
                        idx = tl.load(enter_ind_ptr + e0 + t)
                        # int64 offset: idx * p can exceed int32 for large n*p.
                        row_off = idx.to(tl.int64) * p
                        elx = tl.load(e_eta_ptr + idx)

                        # xp0 += elx
                        old = tl.load(ws_ptr + WS_XP0)
                        tl.store(ws_ptr + WS_XP0, old + elx)

                        # xp1[j] += elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                xval = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1 + j)
                                tl.store(ws_ptr + WS_XP1 + j, old + elx * xval)

                        # xp2[j*P+k] += elx * X[idx,j] * X[idx,k]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                        tl.store(ws_ptr + WS_XP2 + j * P + k, old + elx * vj * vk)

                # ---- Fail phase ----
                f0 = tl.load(fail_ptr_ptr + ii)
                f1 = tl.load(fail_ptr_ptr + ii + 1)
                m = f1 - f0

                if m > 0:
                    # Zero xp1f and xp2f in workspace
                    for j in range(0, P, 1):
                        if j < p:
                            tl.store(ws_ptr + WS_XP1F + j, 0.0)
                    for j in range(0, P, 1):
                        if j < p:
                            for k in range(0, P, 1):
                                if k < p:
                                    tl.store(ws_ptr + WS_XP2F + j * P + k, 0.0)

                    # Accumulate fail sums into xp1f, xp2f, xp0f
                    xp0f_acc = zero_f64
                    for t in range(0, m, 1):
                        idx = tl.load(fail_ind_ptr + f0 + t)
                        row_off = idx.to(tl.int64) * p
                        elx = tl.load(e_eta_ptr + idx)
                        xp0f_acc = xp0f_acc + elx

                        # grad[j] += X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                old = tl.load(grad_ptr + j)
                                tl.store(grad_ptr + j, old + vj)

                        # xp1f[j] += elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1F + j)
                                tl.store(ws_ptr + WS_XP1F + j, old + elx * vj)

                        # xp2f[j*P+k] += elx * X[idx,j] * X[idx,k]
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2F + j * P + k)
                                        tl.store(ws_ptr + WS_XP2F + j * P + k, old + elx * vj * vk)

                    # Efron tie correction (serial): the kk-th of m tied
                    # failures uses c0 = xp0 - (kk/m) * xp0f.  Jk is formed
                    # in fp64 so it does not lose precision against xp0.
                    xp0v = tl.load(ws_ptr + WS_XP0)
                    m_f64 = m.to(tl.float64)
                    sum_inv_c0 = zero_f64
                    sum_J_c0 = zero_f64
                    sum_aa = zero_f64
                    sum_bb = zero_f64
                    sum_ab = zero_f64
                    for kk in range(0, m, 1):
                        Jk = kk.to(tl.float64) / m_f64
                        c0 = xp0v - Jk * xp0f_acc
                        if c0 < 1e-300:
                            c0 = 1e-300
                        ak = 1.0 / c0
                        bk = Jk * ak
                        sum_inv_c0 = sum_inv_c0 + ak
                        sum_J_c0 = sum_J_c0 + Jk / c0
                        sum_aa = sum_aa + ak * ak
                        sum_bb = sum_bb + bk * bk
                        sum_ab = sum_ab + ak * bk

                    # Apply to grad
                    for j in range(0, P, 1):
                        if j < p:
                            xp1j = tl.load(ws_ptr + WS_XP1 + j)
                            xp1fj = tl.load(ws_ptr + WS_XP1F + j)
                            old = tl.load(grad_ptr + j)
                            tl.store(grad_ptr + j, old - (xp1j * sum_inv_c0 - xp1fj * sum_J_c0))

                    # Apply to hess
                    for j in range(0, P, 1):
                        if j < p:
                            for k in range(0, P, 1):
                                if k < p:
                                    xp2jk = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                    xp2fjk = tl.load(ws_ptr + WS_XP2F + j * P + k)
                                    hess_val = xp2jk * sum_inv_c0 - xp2fjk * sum_J_c0

                                    xp1j_v = tl.load(ws_ptr + WS_XP1 + j)
                                    xp1k_v = tl.load(ws_ptr + WS_XP1 + k)
                                    xp1fj_v = tl.load(ws_ptr + WS_XP1F + j)
                                    xp1fk_v = tl.load(ws_ptr + WS_XP1F + k)
                                    o11 = xp1j_v * xp1k_v
                                    off_v = xp1fj_v * xp1fk_v
                                    cross_v = xp1j_v * xp1fk_v + xp1fj_v * xp1k_v
                                    # sum_k (a_k*xp1 - b_k*xp1f)(a_k*xp1 - b_k*xp1f)^T
                                    hsub = sum_aa * o11 + sum_bb * off_v - sum_ab * cross_v
                                    hess_val = hess_val - hsub

                                    idx2 = j * P + k
                                    old = tl.load(hess_ptr + idx2)
                                    tl.store(hess_ptr + idx2, hess_val + old)

                # ---- Exit phase: remove rows listed for index ii ----
                x0 = tl.load(exit_ptr_ptr + ii)
                x1 = tl.load(exit_ptr_ptr + ii + 1)
                nx = x1 - x0

                if nx > 0:
                    for t in range(0, nx, 1):
                        idx = tl.load(exit_ind_ptr + x0 + t)
                        row_off = idx.to(tl.int64) * p
                        elx = tl.load(e_eta_ptr + idx)

                        # xp0 -= elx
                        old = tl.load(ws_ptr + WS_XP0)
                        tl.store(ws_ptr + WS_XP0, old - elx)

                        # xp1[j] -= elx * X[idx,j]
                        for j in range(0, P, 1):
                            if j < p:
                                xval = tl.load(X_ptr + row_off + j)
                                old = tl.load(ws_ptr + WS_XP1 + j)
                                tl.store(ws_ptr + WS_XP1 + j, old - elx * xval)

                        # xp2 -= elx * X^T X
                        for j in range(0, P, 1):
                            if j < p:
                                vj = tl.load(X_ptr + row_off + j)
                                for k in range(0, P, 1):
                                    if k < p:
                                        vk = tl.load(X_ptr + row_off + k)
                                        old = tl.load(ws_ptr + WS_XP2 + j * P + k)
                                        tl.store(ws_ptr + WS_XP2 + j * P + k, old - elx * vj * vk)

        HAS_TRITON_EFRON = True

    except Exception:
        # Defining the kernel failed; disable the Efron Triton path.
        HAS_TRITON_EFRON = False
        _triton = None
        _tl = None
|
|
247
|
+
|
|
248
|
+
# =====================================================================
# Breslow gradient/Hessian — optional kernel module
# =====================================================================
# Historical note (kept from the original header): a Triton serial-scan
# kernel was originally attempted, but Triton 2.0 has a compiler bug
# producing non-deterministic wrong code for kernels with runtime-bounded
# loops (while/for with >= 3 iterations); a PyTorch GPU path (cuBLAS
# matmul + vectorized ops) was described as only marginally slower.
# NOTE(review): the import below nevertheless pulls
# ``compute_breslow_grad_hess_triton`` from ``_cox_breslow_triton_kernel``;
# confirm which implementation that module actually provides.
try:
    from statgpu.survival._cox_breslow_triton_kernel import (
        compute_breslow_grad_hess_triton,
        _find_p_ce as _find_p_ce_breslow,
    )
    HAS_TRITON_BRESLOW = True
except Exception:
    # Any exception from the import disables this path.  Bind every name
    # the success branch would bind so the module namespace is the same
    # shape either way and callers can test the values against None.
    compute_breslow_grad_hess_triton = None
    _find_p_ce_breslow = None
    HAS_TRITON_BRESLOW = False
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _triton_available() -> bool:
    """Report whether the Efron Triton kernel was defined at import time."""
    efron_ready = HAS_TRITON_EFRON
    return efron_ready
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# Padded feature widths the Efron kernel is compiled for (its constexpr P).
_SUPPORTED_P: Tuple[int, ...] = (8, 16, 32, 64, 128)


def _find_p_ce(p: int) -> Optional[int]:
    """Return the smallest supported padded width that is ``>= p``.

    Returns ``None`` when ``p`` is larger than every entry of
    ``_SUPPORTED_P``, meaning the Triton path cannot be used.
    """
    return next((width for width in _SUPPORTED_P if width >= p), None)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def compute_efron_grad_hess_triton(
    X: Any,
    beta: Any,
    efron_pre: Any,
) -> Optional[Tuple[Any, Any]]:
    """Compute Efron gradient/Hessian via the Triton serial kernel.

    Parameters
    ----------
    X : torch.Tensor of shape (n, p)
        Design matrix on the target device.  It is cast to float64 and made
        row-major contiguous before launch, because the kernel addresses row
        ``i`` at flat offset ``i * p`` and loads float64 values.
    beta : array-like
        Coefficients; converted to a float64 tensor on ``X.device``.
    efron_pre : tuple
        Precomputed Efron structures with 5 or 6 items; items 1-4 are
        ``(uft_ix, risk_enter, risk_exit, nuft)``.

    Returns
    -------
    tuple of (torch.Tensor, torch.Tensor) or None
        ``(grad, hess)`` with shapes ``(p,)`` and ``(p, p)``, float64 on
        ``X.device``; ``hess`` is the negation of the matrix the kernel
        accumulates.  ``None`` when the kernel is unavailable, ``p`` exceeds
        the largest supported padded width, or the launch raises.
    """
    if not HAS_TRITON_EFRON:
        return None

    import torch
    from statgpu.survival._cox_efron_cuda import efron_indices_to_csr

    if len(efron_pre) == 6:
        _, uft_ix, risk_enter, risk_exit, nuft, _ = efron_pre
    else:
        _, uft_ix, risk_enter, risk_exit, nuft = efron_pre

    p = int(X.shape[1])
    p_ce = _find_p_ce(p)
    if p_ce is None:
        return None

    device = X.device
    if nuft == 0:
        return (
            torch.zeros(p, dtype=torch.float64, device=device),
            torch.zeros((p, p), dtype=torch.float64, device=device),
        )

    # The kernel reads X as float64 with row stride p; both calls are
    # no-ops when X is already a contiguous float64 tensor.
    X = X.to(dtype=torch.float64).contiguous()
    beta = torch.as_tensor(beta, dtype=torch.float64, device=device)
    n = int(X.shape[0])

    # Build linear predictor; shifting by the max keeps exp() from
    # overflowing and cancels in every ratio the kernel forms.
    linpred = X @ beta
    linpred = linpred - torch.max(linpred)
    e_eta = torch.exp(linpred)

    # Build CSR
    enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind = efron_indices_to_csr(
        uft_ix, risk_enter, risk_exit, nuft
    )

    enter_ptr_t = torch.as_tensor(enter_ptr, dtype=torch.int32, device=device)
    enter_ind_t = torch.as_tensor(enter_ind, dtype=torch.int32, device=device)
    exit_ptr_t = torch.as_tensor(exit_ptr, dtype=torch.int32, device=device)
    exit_ind_t = torch.as_tensor(exit_ind, dtype=torch.int32, device=device)
    fail_ptr_t = torch.as_tensor(fail_ptr, dtype=torch.int32, device=device)
    fail_ind_t = torch.as_tensor(fail_ind, dtype=torch.int32, device=device)

    # Workspace: WS_XP0(1) + WS_XP1(P) + WS_XP2(P*P) + WS_HESS(P*P) +
    #            WS_XP1F(P) + WS_XP2F(P*P) + WS_SCRATCH(1)
    # (sized with 3*P vector slots, one more than the layout needs).
    ws_size = 1 + 3 * p_ce + 3 * p_ce * p_ce + 1
    ws = torch.zeros(ws_size, dtype=torch.float64, device=device)
    grad_out = torch.zeros(p, dtype=torch.float64, device=device)
    # Allocate hess_out with padded stride (p_ce) to match Triton kernel indexing
    hess_out = torch.zeros(p_ce * p_ce, dtype=torch.float64, device=device)

    try:
        _efron_backward_scan_serial[(1,)](
            X, e_eta,
            enter_ptr_t, enter_ind_t,
            exit_ptr_t, exit_ind_t,
            fail_ptr_t, fail_ind_t,
            ws, grad_out, hess_out,
            n, p, nuft,
            P=p_ce,
        )
        torch.cuda.synchronize()
    except Exception:
        # Deliberate fallback: callers treat None as "use another path".
        return None

    return grad_out, -hess_out.view(p_ce, p_ce)[:p, :p]
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
# compute_breslow_grad_hess_triton is imported from _cox_breslow_triton_kernel.py
# above (in the try/except block); that module's _find_p_ce is bound here as
# _find_p_ce_breslow, separate from this module's own _find_p_ce.
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Unsupervised learning estimators."""
|
|
2
|
+
|
|
3
|
+
from ._pca import PCA
|
|
4
|
+
from ._kmeans import KMeans
|
|
5
|
+
from ._dbscan import DBSCAN
|
|
6
|
+
from ._gmm import GaussianMixture
|
|
7
|
+
from ._nmf import NMF
|
|
8
|
+
from ._agglomerative import AgglomerativeClustering
|
|
9
|
+
from ._truncated_svd import TruncatedSVD
|
|
10
|
+
from ._minibatch_kmeans import MiniBatchKMeans
|
|
11
|
+
from ._incremental_pca import IncrementalPCA
|
|
12
|
+
from ._minibatch_nmf import MiniBatchNMF
|
|
13
|
+
from ._umap import UMAP
|
|
14
|
+
from ._tsne import TSNE
|
|
15
|
+
|
|
16
|
+
# Public estimators re-exported by ``statgpu.unsupervised``.
__all__ = [
    "PCA", "KMeans", "DBSCAN", "GaussianMixture",
    "NMF", "AgglomerativeClustering", "TruncatedSVD", "MiniBatchKMeans",
    "IncrementalPCA", "MiniBatchNMF", "UMAP", "TSNE",
]
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""Agglomerative clustering."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from scipy.cluster.hierarchy import fcluster, linkage
|
|
11
|
+
|
|
12
|
+
from statgpu._base import BaseEstimator
|
|
13
|
+
from statgpu._config import Device
|
|
14
|
+
from statgpu.unsupervised._utils import check_2d_array, reject_sparse, squared_euclidean_distances
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Default byte budget (1 GiB) for the dense n x n GPU distance matrix.
DEFAULT_GPU_DISTANCE_LIMIT_BYTES = 1 << 30


def _gpu_distance_limit_bytes() -> int:
    """Return the byte budget for the dense GPU distance matrix.

    Reads ``STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES``.  When it is unset, the
    default ``DEFAULT_GPU_DISTANCE_LIMIT_BYTES`` is returned.  A value that
    does not parse as an integer, or that is negative (not a meaningful byte
    count), emits a ``RuntimeWarning`` and falls back to the default.
    """
    value = os.environ.get("STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES")
    if value is None:
        return DEFAULT_GPU_DISTANCE_LIMIT_BYTES
    try:
        limit: Optional[int] = int(value)
    except ValueError:
        limit = None
    if limit is None or limit < 0:
        warnings.warn(
            "Invalid STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES value; "
            f"using default {DEFAULT_GPU_DISTANCE_LIMIT_BYTES} bytes.",
            RuntimeWarning,
            stacklevel=2,
        )
        return DEFAULT_GPU_DISTANCE_LIMIT_BYTES
    return limit
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AgglomerativeClustering(BaseEstimator):
    """Exact dense agglomerative clustering.

    Parameters
    ----------
    n_clusters : int, default=2
        Number of flat clusters extracted from the merge tree.
    linkage : {"single", "complete", "average", "ward"}, default="single"
        Linkage criterion.
    metric : str, default="euclidean"
        Only ``"euclidean"`` is supported.
    device : str or Device, default=Device.AUTO
        Forwarded to ``BaseEstimator``.  When ``self.device`` is
        ``Device.CUDA`` or ``Device.TORCH`` the dense GPU path is used;
        otherwise SciPy's ``linkage`` runs on the CPU.
    n_jobs : int, optional
        Forwarded to ``BaseEstimator``.

    Attributes
    ----------
    labels_ : ndarray of shape (n_samples,)
        Flat label per sample, obtained on every path by applying the first
        ``n_samples - n_clusters`` merges, so exactly ``n_clusters`` clusters
        are produced.  Samples never merged are numbered first (in index
        order), then merged clusters in order of creation.
    children_ : ndarray of shape (n_samples - 1, 2)
        Merge tree: row ``i`` joins two existing clusters into cluster
        ``n_samples + i``; each row lists the smaller id first.
    distances_ : ndarray of shape (n_samples - 1,)
        Linkage distance of each merge.
    n_features_in_ : int
        Number of input features.
    """

    # Byte budget for the dense n x n GPU distance matrix.  Read from the
    # environment once, when this class body executes (module import).
    _GPU_DISTANCE_LIMIT_BYTES = _gpu_distance_limit_bytes()

    def __init__(
        self,
        n_clusters: int = 2,
        linkage: str = "single",
        metric: str = "euclidean",
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_clusters = n_clusters
        self.linkage = linkage
        self.metric = metric

    def _validate_params(self, n_samples: int):
        """Raise if the hyper-parameters are invalid for ``n_samples`` rows."""
        if not isinstance(self.n_clusters, (int, np.integer)) or int(self.n_clusters) < 1:
            raise ValueError("n_clusters must be a positive integer")
        if int(self.n_clusters) > n_samples:
            raise ValueError("n_clusters must be less than or equal to n_samples")
        if self.linkage not in ("single", "complete", "average", "ward"):
            raise ValueError("linkage must be one of: 'single', 'complete', 'average', 'ward'")
        if self.metric != "euclidean":
            raise NotImplementedError("AgglomerativeClustering only supports metric='euclidean'")

    def _use_gpu_path(self) -> bool:
        """Return True when the configured device selects the dense GPU path."""
        return self.device in (Device.CUDA, Device.TORCH)

    def _check_gpu_memory(self, n_samples: int):
        """Raise ``MemoryError`` if an n x n float64 matrix exceeds the budget."""
        required = int(n_samples) * int(n_samples) * 8
        if required > self._GPU_DISTANCE_LIMIT_BYTES:
            limit_mb = self._GPU_DISTANCE_LIMIT_BYTES / (1024**2)
            required_mb = required / (1024**2)
            raise MemoryError(
                "AgglomerativeClustering GPU exact path requires a dense "
                f"distance matrix of about {required_mb:.1f} MiB, exceeding "
                f"the configured limit {limit_mb:.1f} MiB. Use device='cpu' "
                "or raise STATGPU_AGGLOMERATIVE_GPU_MAX_BYTES explicitly."
            )

    def _record_fit(self, children, distances, labels, n_features, backend_name):
        """Store the fitted attributes, mark the estimator fitted, return self."""
        self.children_ = children
        self.distances_ = distances
        self.labels_ = labels
        self.n_features_in_ = int(n_features)
        self._backend_name = backend_name
        self._fitted = True
        return self

    @staticmethod
    def _labels_from_children(n_samples: int, n_clusters: int, children: np.ndarray) -> np.ndarray:
        """Cut the merge tree by applying its first ``n_samples - n_clusters`` merges.

        Row ``i`` of ``children`` is assumed to create cluster ``n_samples + i``.
        Labels follow ascending top-level cluster id: untouched samples in index
        order first, then merged clusters in creation order.
        """
        merges_to_apply = max(0, n_samples - int(n_clusters))
        applied = np.asarray(children, dtype=np.int64).reshape(-1, 2)[:merges_to_apply]
        n_nodes = n_samples + len(applied)

        # parent[node] is the id of the cluster that absorbed ``node``; -1 for roots.
        parent = np.full(n_nodes, -1, dtype=np.int64)
        new_ids = n_samples + np.arange(len(applied), dtype=np.int64)
        parent[applied[:, 0]] = new_ids
        parent[applied[:, 1]] = new_ids

        # Pointer doubling: each pass doubles how far every pointer jumps, so
        # O(log depth) vectorized passes reach each node's top-level cluster
        # (avoids the O(n^2) member-list concatenation of a naive merge).
        root = np.where(parent >= 0, parent, np.arange(n_nodes, dtype=np.int64))
        while True:
            hop = root[root]
            if np.array_equal(hop, root):
                break
            root = hop

        _, labels = np.unique(root[:n_samples], return_inverse=True)
        return labels.astype(np.int64, copy=False).reshape(n_samples)

    @staticmethod
    def _single_linkage_from_mst(
        n_samples: int,
        edge_parents: np.ndarray,
        edge_children: np.ndarray,
        edge_weights: np.ndarray,
    ):
        """Turn MST edges into a single-linkage merge tree.

        Edges are processed by ascending weight (stable sort) with a
        union-find; each accepted edge merges two clusters into a new id
        ``n_samples + step``.  Returns ``(children, distances)``.
        """
        order = np.argsort(edge_weights, kind="mergesort")
        uf_parent = list(range(n_samples))
        cluster_ids = list(range(n_samples))
        children = np.empty((n_samples - 1, 2), dtype=np.int64)
        distances = np.empty(n_samples - 1, dtype=np.float64)

        def find(idx: int) -> int:
            # Path halving keeps the trees shallow.
            while uf_parent[idx] != idx:
                uf_parent[idx] = uf_parent[uf_parent[idx]]
                idx = uf_parent[idx]
            return idx

        merge_step = 0
        for edge_idx in order:
            left_root = find(int(edge_parents[edge_idx]))
            right_root = find(int(edge_children[edge_idx]))
            if left_root == right_root:
                continue
            id_left, id_right = cluster_ids[left_root], cluster_ids[right_root]
            # SciPy convention: smaller cluster id first.
            children[merge_step] = (min(id_left, id_right), max(id_left, id_right))
            distances[merge_step] = float(edge_weights[edge_idx])
            uf_parent[right_root] = left_root
            cluster_ids[left_root] = n_samples + merge_step
            merge_step += 1
            if merge_step == n_samples - 1:
                break

        return children, distances

    def _fit_gpu_single(self, backend, X_arr, n_samples: int):
        """Single linkage via Prim's MST on the dense Euclidean distance matrix."""
        D = backend.sqrt(squared_euclidean_distances(backend, X_arr))
        inf = float("inf")
        indices = backend.arange(n_samples, dtype=backend.int64)
        D[indices, indices] = inf

        selected = backend.zeros(n_samples, dtype=backend.bool)
        selected[0] = True
        min_dist = backend.copy(D[0, :])
        min_dist[0] = inf
        nearest_parent = backend.zeros(n_samples, dtype=backend.int64)

        edge_parents = np.empty(n_samples - 1, dtype=np.int64)
        edge_children = np.empty(n_samples - 1, dtype=np.int64)
        edge_weights = np.empty(n_samples - 1, dtype=np.float64)

        for step in range(n_samples - 1):
            # Closest unselected node to the tree (selected nodes hold inf).
            child = int(float(backend.argmin(min_dist)))
            edge_children[step] = child
            edge_parents[step] = int(float(nearest_parent[child]))
            edge_weights[step] = float(min_dist[child])

            selected[child] = True
            candidate = D[child, :]
            update_mask = (candidate < min_dist) & (~selected)
            nearest_parent[update_mask] = child
            min_dist = backend.where(update_mask, candidate, min_dist)
            min_dist[child] = inf

        return self._single_linkage_from_mst(n_samples, edge_parents, edge_children, edge_weights)

    def _fit_gpu_lance_williams(self, backend, X_arr, n_samples: int):
        """Greedy closest-pair merging with Lance-Williams row/column updates.

        Works on the full dense matrix (squared Euclidean for ward, Euclidean
        otherwise).  Each step merges the globally closest pair into the
        lower-index slot and retires the other slot with inf.  Returns
        ``(children, distances)``.
        """
        D = squared_euclidean_distances(backend, X_arr)
        if self.linkage != "ward":
            D = backend.sqrt(D)
        inf = float("inf")
        indices = backend.arange(n_samples, dtype=backend.int64)
        D[indices, indices] = inf

        children = np.empty((n_samples - 1, 2), dtype=np.int64)
        distances = np.empty(n_samples - 1, dtype=np.float64)
        cluster_ids = list(range(n_samples))
        cluster_sizes = [1.0] * n_samples
        cluster_sizes_backend = (
            backend.asarray(cluster_sizes, dtype=backend.float64) if self.linkage == "ward" else None
        )

        for step in range(n_samples - 1):
            a, b = divmod(int(float(backend.argmin(D))), n_samples)
            if b < a:
                a, b = b, a

            merge_value = float(D[a, b])
            id_a, id_b = cluster_ids[a], cluster_ids[b]
            # SciPy convention: smaller cluster id first.
            children[step] = (min(id_a, id_b), max(id_a, id_b))
            # Ward runs on squared distances; report the Euclidean height.
            distances[step] = np.sqrt(max(merge_value, 0.0)) if self.linkage == "ward" else merge_value

            da = D[a, :]
            db = D[b, :]
            size_a = cluster_sizes[a]
            size_b = cluster_sizes[b]

            if self.linkage == "single":
                updated = backend.minimum(da, db)
            elif self.linkage == "complete":
                if backend.name in ("cupy", "torch"):
                    # In place into row a; row a is overwritten below anyway.
                    backend.xp.maximum(da, db, out=da)
                    updated = da
                else:
                    updated = backend.maximum(da, db)
            elif self.linkage == "average":
                if backend.name in ("cupy", "torch"):
                    da *= size_a
                    da += size_b * db
                    da /= size_a + size_b
                    updated = da
                else:
                    updated = (size_a * da + size_b * db) / (size_a + size_b)
            else:
                # Lance-Williams ward update on squared distances.
                total = size_a + size_b + cluster_sizes_backend
                updated = (
                    ((cluster_sizes_backend + size_a) / total) * da
                    + ((cluster_sizes_backend + size_b) / total) * db
                    - (cluster_sizes_backend / total) * merge_value
                )
                updated = backend.maximum(updated, 0.0)

            D[a, :] = updated
            D[:, a] = updated
            cluster_ids[a] = n_samples + step
            cluster_sizes[a] += cluster_sizes[b]
            cluster_sizes[b] = 0.0
            if cluster_sizes_backend is not None:
                cluster_sizes_backend[a] = cluster_sizes[a]
                cluster_sizes_backend[b] = 0.0
            D[b, :] = inf
            D[:, b] = inf
            D[a, a] = inf

        return children, distances

    def _fit_gpu(self, X):
        """Fit with an exact dense algorithm on the configured GPU backend."""
        backend = self._get_backend()
        X_arr = self._to_array(X, backend=backend.name)
        X_arr = backend.asarray(X_arr, dtype=backend.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        self._validate_params(n_samples)
        self._check_gpu_memory(n_samples)

        if n_samples == 1:
            return self._record_fit(
                np.empty((0, 2), dtype=np.int64),
                np.empty((0,), dtype=np.float64),
                np.zeros(1, dtype=np.int64),
                n_features,
                backend.name,
            )

        if self.linkage == "single" and backend.name in ("cupy", "torch"):
            children, distances = self._fit_gpu_single(backend, X_arr, n_samples)
        else:
            children, distances = self._fit_gpu_lance_williams(backend, X_arr, n_samples)

        labels = self._labels_from_children(n_samples, int(self.n_clusters), children)
        return self._record_fit(children, distances, labels, n_features, backend.name)

    def fit(self, X, y=None):
        """Fit the hierarchy on ``X``; ``y`` is ignored.  Returns ``self``."""
        reject_sparse(X, "AgglomerativeClustering")
        if self._use_gpu_path():
            return self._fit_gpu(X)

        X_arr = np.asarray(X, dtype=np.float64)
        check_2d_array(X_arr)
        n_samples, n_features = X_arr.shape
        self._validate_params(n_samples)

        if n_samples == 1:
            children = np.empty((0, 2), dtype=np.int64)
            distances = np.empty((0,), dtype=np.float64)
            labels = np.zeros(1, dtype=np.int64)
        else:
            Z = linkage(X_arr, method=self.linkage, metric="euclidean")
            children = Z[:, :2].astype(np.int64, copy=False)
            distances = Z[:, 2].astype(np.float64, copy=False)
            # Cut by applying the first n_samples - n_clusters merges instead of
            # fcluster(criterion="maxclust"): with tied merge heights maxclust
            # can return fewer than n_clusters clusters.  This also matches the
            # GPU path's labelling.
            labels = self._labels_from_children(n_samples, int(self.n_clusters), children)

        return self._record_fit(children, distances, labels, n_features, "numpy")

    def fit_predict(self, X, y=None):
        """Fit on ``X`` and return ``labels_``."""
        return self.fit(X, y=y).labels_

    def predict(self, X):
        """Not supported: agglomerative clustering has no out-of-sample rule."""
        raise NotImplementedError("AgglomerativeClustering does not support predict for unseen samples")

    def get_params(self, deep=True):
        """Return base-estimator params plus n_clusters, linkage and metric."""
        params = super().get_params(deep=deep)
        params.update(
            {
                "n_clusters": self.n_clusters,
                "linkage": self.linkage,
                "metric": self.metric,
            }
        )
        return params
|