fastkmeans 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastkmeans/__init__.py +1 -1
- fastkmeans/kmeans.py +57 -42
- fastkmeans/triton_kernels.py +96 -69
- {fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/METADATA +31 -5
- fastkmeans-0.4.0.dist-info/RECORD +8 -0
- {fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/WHEEL +1 -1
- fastkmeans-0.3.0.dist-info/RECORD +0 -8
- {fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/top_level.txt +0 -0
fastkmeans/__init__.py
CHANGED
fastkmeans/kmeans.py
CHANGED
@@ -3,24 +3,33 @@ import time
 import torch
 import numpy as np
 
+try:
+    from fastkmeans.triton_kernels import triton_kmeans
+
+    HAS_TRITON = True
+except ImportError:
+    triton_kmeans = None
+    HAS_TRITON = False
+
+
 def _get_device(preset: str | int | torch.device | None = None):
     if isinstance(preset, torch.device):
         return preset
     if isinstance(preset, str):
         return torch.device(preset)
-    if torch.cuda.is_available():
+    if torch.cuda.is_available():  # cuda currently handles both AMD and NVIDIA GPUs
         return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
-    if hasattr(torch.backends,
-        return torch.device(
-    if hasattr(torch,
-        return torch.device(f
-    return torch.device(
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.device(f"xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}")
+    return torch.device("cpu")
 
 
-def _is_bfloat16_supported(device:torch.device):
-    if device.type ==
+def _is_bfloat16_supported(device: torch.device):
+    if device.type == "cuda":
         return torch.cuda.is_bf16_supported()
-    elif device.type ==
+    elif device.type == "xpu" and hasattr(torch.xpu, "is_bf16_supported"):
        return torch.xpu.is_bf16_supported()
    else:
        return False
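The guarded import above lets the package fall back to the pure-PyTorch path when Triton is missing, and `_get_device` now resolves devices in the order: explicit device > CUDA (NVIDIA/AMD) > MPS > XPU > CPU. A minimal sketch of exercising that order, assuming the private helpers are imported from `fastkmeans.kmeans` (they are internal, not public API):

```python
import torch

from fastkmeans.kmeans import _get_device, _is_bfloat16_supported

# Resolution order: explicit device > cuda > mps > xpu > cpu
device = _get_device()  # e.g. cuda:0 on an NVIDIA or AMD machine
print(device, _is_bfloat16_supported(device))

# An int picks that GPU index when it exists, otherwise falls back to index 0
print(_get_device(1))  # cuda:1 only if a second CUDA device is present
```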
@@ -52,10 +61,11 @@ def _kmeans_torch_double_chunked(
     """
 
     if use_triton:
-
+        if not HAS_TRITON:
+            raise ImportError("Triton is not available. Please install Triton and try again.")
 
     if dtype is None:
-        dtype = torch.float16 if device.type in [
+        dtype = torch.float16 if device.type in ["cuda", "xpu"] else torch.float32
 
     n_samples_original, n_features = data.shape
     n_samples = n_samples_original
@@ -77,12 +87,12 @@ def _kmeans_torch_double_chunked(
     centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
     prev_centroids = centroids.clone()
 
-    labels = torch.empty(n_samples, dtype=torch.int64, device=
+    labels = torch.empty(n_samples, dtype=torch.int64, device="cpu")  # Keep labels on CPU
 
     for iteration in range(max_iters):
         iteration_start_time = time.time()
 
-        centroid_norms = (centroids
+        centroid_norms = (centroids**2).sum(dim=1)
         cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
         cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
 
@@ -96,7 +106,7 @@ def _kmeans_torch_double_chunked(
             best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
 
             if use_triton:
-
+                triton_kmeans(
                     data_chunk=data_chunk,
                     data_chunk_norms=data_chunk_norms,
                     centroids=centroids,
@@ -104,7 +114,7 @@ def _kmeans_torch_double_chunked(
                     best_ids=best_ids,
                 )
             else:
-                best_dist = torch.full((batch_size,), float(
+                best_dist = torch.full((batch_size,), float("inf"), device=device, dtype=dtype)
                 c_start = 0
                 while c_start < k:
                     c_end = min(c_start + chunk_size_centroids, k)
@@ -117,23 +127,23 @@ def _kmeans_torch_double_chunked(
                     local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                     improved_mask = local_min_vals < best_dist
                     best_dist[improved_mask] = local_min_vals[improved_mask]
-                    best_ids[improved_mask] =
+                    best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
 
                     c_start = c_end
 
             cluster_sums.index_add_(0, best_ids, data_chunk.float())
             cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
 
-            labels[start_idx:end_idx] = best_ids.to(
+            labels[start_idx:end_idx] = best_ids.to("cpu", non_blocking=True)
             start_idx = end_idx
 
         new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
-        non_empty =
+        non_empty = cluster_counts > 0
         new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
 
         empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
         if len(empty_ids) > 0:
-            reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device=
+            reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device="cpu")
             random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
             new_centroids[empty_ids] = random_data
 
@@ -144,14 +154,16 @@ def _kmeans_torch_double_chunked(
 
         iteration_time = time.time() - iteration_start_time
         if verbose:
-            print(
+            print(
+                f"Iteration {iteration + 1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}"
+            )
 
         if shift < tol:
             if verbose:
-                print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
+                print(f"Converged after {iteration + 1} iterations (shift: {shift:.6f} < tol: {tol})")
             break
 
-    centroids_cpu = centroids.to(
+    centroids_cpu = centroids.to("cpu", dtype=torch.float32)
     return centroids_cpu, labels
 
 
@@ -177,9 +189,9 @@ class FastKMeans:
     max_points_per_centroid : int, optional, default=1_000_000_000
         If n_samples > k * max_points_per_centroid, the data will be subsampled to exactly
         k * max_points_per_centroid points before clustering.
-    chunk_size_data : int, default=
+    chunk_size_data : int, default=51_200
         Chunk size along the data dimension for assignment/update steps.
-    chunk_size_centroids : int, default=
+    chunk_size_centroids : int, default=10_240
        Chunk size along the centroid dimension for assignment/update steps.
    use_triton : bool | None, default=None
        Use the fast Triton backend for the assignment/update steps.
@@ -194,14 +206,14 @@ class FastKMeans:
         tol: float = 1e-8,
         gpu: bool = True,
         seed: int = 0,
-        max_points_per_centroid: int =
-        chunk_size_data: int =
-        chunk_size_centroids: int =
+        max_points_per_centroid: int = 1_000_000_000,
+        chunk_size_data: int = 51_200,
+        chunk_size_centroids: int = 10_240,
         device: str | int | torch.device | None = None,
         dtype: torch.dtype = None,
         pin_gpu_memory: bool = True,
         verbose: bool = False,
-        nredo: int = 1,
+        nredo: int = 1,  # for compatibility only
         use_triton: bool | None = None,
     ):
         self.d = d
@@ -217,8 +229,11 @@ class FastKMeans:
         self.dtype = dtype
         self.pin_gpu_memory = pin_gpu_memory
         self.verbose = verbose
-        if use_triton is
-
+        if use_triton is None:
+            # assume triton kernel is supported if GPU supports bfloat16
+            use_triton = HAS_TRITON and _is_bfloat16_supported(self.device)
+        if use_triton and not HAS_TRITON:
+            raise ValueError("Triton is not available. Please install Triton and try again.")
         self.use_triton = use_triton
         if nredo != 1:
             raise ValueError("nredo must be 1, redos not currently supported")
@@ -237,10 +252,10 @@ class FastKMeans:
 
         # Move data to PyTorch CPU Tensor
         data_torch = torch.from_numpy(data)
-        data_norms_torch = (data_torch
+        data_norms_torch = (data_torch**2).sum(dim=1)
 
         device = _get_device(self.device)
-        if device ==
+        if device == "cuda" and self.pin_gpu_memory:
             data_torch = data_torch.pin_memory()
             data_norms_torch = data_norms_torch.pin_memory()
 
@@ -275,33 +290,33 @@ class FastKMeans:
         -------
         labels : np.ndarray of shape (n_samples,), int64
         """
-        if self.use_triton:
-            from fastkmeans.triton_kernels import chunked_kmeans_kernel
         if self.centroids is None:
             raise RuntimeError("Must call train() or fit() before predict().")
 
         data_torch = torch.from_numpy(data)
-        data_norms_torch = (data_torch
+        data_norms_torch = (data_torch**2).sum(dim=1)
 
         # We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
         centroids_torch = torch.from_numpy(self.centroids)
         centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
-        centroid_norms = (centroids_torch
+        centroid_norms = (centroids_torch**2).sum(dim=1)
 
         n_samples = data_torch.shape[0]
-        labels = torch.empty(n_samples, dtype=torch.long, device=
+        labels = torch.empty(n_samples, dtype=torch.long, device="cpu")
 
         start_idx = 0
         while start_idx < n_samples:
             end_idx = min(start_idx + self.chunk_size_data, n_samples)
 
             data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
-            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(
+            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(
+                device=self.device, dtype=torch.float32, non_blocking=True
+            )
             batch_size = data_chunk.size(0)
             best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
 
             if self.use_triton:
-
+                triton_kmeans(
                     data_chunk,
                     data_chunk_norms,
                     centroids_torch,
@@ -309,7 +324,7 @@ class FastKMeans:
                     best_ids,
                 )
             else:
-                best_dist = torch.full((batch_size,), float(
+                best_dist = torch.full((batch_size,), float("inf"), device=self.device, dtype=torch.float32)
                 c_start = 0
                 k = centroids_torch.shape[0]
                 while c_start < k:
@@ -323,10 +338,10 @@ class FastKMeans:
                     local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                     improved_mask = local_min_vals < best_dist
                     best_dist[improved_mask] = local_min_vals[improved_mask]
-                    best_ids[improved_mask] =
+                    best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
                     c_start = c_end
 
-            labels[start_idx:end_idx] = best_ids.to(
+            labels[start_idx:end_idx] = best_ids.to("cpu")
             start_idx = end_idx
 
         return labels.numpy()
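Taken together, the kmeans.py changes wire the Triton kernel into both training and `predict` while keeping the chunked PyTorch fallback. A hedged quick-start sketch: the `d`/`k` constructor arguments are assumed to follow the faiss-style signature the README describes, and `fit`/`predict` are assumed to take float32 numpy arrays, as the `torch.from_numpy` calls in the diff suggest:

```python
import numpy as np

from fastkmeans import FastKMeans  # assumes the top-level export

data = np.random.rand(100_000, 128).astype(np.float32)

km = FastKMeans(d=128, k=1_024, use_triton=None)  # None = auto-detect Triton
km.fit(data)               # chunked Lloyd iterations on the resolved device
labels = km.predict(data)  # chunked nearest-centroid pass, int64 numpy array
```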
fastkmeans/triton_kernels.py
CHANGED
@@ -1,105 +1,132 @@
+from contextlib import contextmanager
+
 import torch
 import triton
 import triton.language as tl
 
 
+@contextmanager
+def device_guard(tensor: torch.Tensor):
+    """Context manager to ensure that the Triton kernel launches on the correct device."""
+    if tensor.device.type == "cuda":  # NVIDIA or AMD/ROCm
+        with torch.cuda.device_of(tensor):
+            yield
+    elif tensor.device.type == "xpu":  # Intel GPUs
+        with torch.xpu.device_of(tensor):
+            yield
+    else:  # CPU or other back-ends
+        yield
+
+
 @triton.heuristics(
     {
-        "BLOCK_M": lambda
-        "BLOCK_N": lambda
-        "
-        "
+        "BLOCK_M": lambda x: 128 if x["D"] <= 384 else 64,
+        "BLOCK_N": lambda x: 128 if x["D"] <= 384 else 64,
+        "BLOCK_K": lambda x: 16 if x["D"] <= 32 or x["D"] > 384 else 32,
+        "GROUP_SIZE_M": lambda x: 8 if x["D"] <= 32 else 16,
+        "num_warps": lambda x: 4,
     }
 )
 @triton.jit
-def
-
-    x_norm_ptr,
-
-
-
-
-
-
+def _kmeans_kernel(
+    x_ptr,
+    x_norm_ptr,
+    c_ptr,
+    c_norm_ptr,
+    best_dist_ptr,
+    best_idx_ptr,
+    B,
+    C,
+    D: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
 ):
-
-
-
-
-
-
-
-
-
-
+    # Map flat CTA id to (pid_m, pid_n) in "grouped" launch order
+    pid = tl.program_id(axis=0)
+
+    num_pid_m = tl.cdiv(B, BLOCK_M)  # row-tiles
+    num_pid_n = tl.cdiv(C, BLOCK_N)  # centroid-tiles
+
+    # Super-group into GROUP_SIZE_M blocks to minimize loading from global memory
+    num_pid_in_grp = GROUP_SIZE_M * num_pid_n
+    first_pid_m = (pid // num_pid_in_grp) * GROUP_SIZE_M
+    group_rows = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+
+    pid_m = first_pid_m + ((pid % num_pid_in_grp) % group_rows)  # row-tile index
+    pid_n = (pid % num_pid_in_grp) // group_rows  # centroid-tile index
+
+    row_start = pid_m * BLOCK_M
+    col_start = pid_n * BLOCK_N
+
     rows = row_start + tl.arange(0, BLOCK_M)
-
+    cols = col_start + tl.arange(0, BLOCK_N)
 
-
-
-    x = tl.load(data_ptr + row_offsets, mask=mask[:, None], other=0.0)
+    row_mask = rows < B
+    col_mask = cols < C
 
-    #
-
+    # load norms
+    x_n = tl.load(x_norm_ptr + rows, mask=row_mask, other=0.0)  # [BM]
+    c_n = tl.load(c_norm_ptr + cols, mask=col_mask, other=0.0)  # [BN]
 
-    #
-
-    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
+    # pipelined K-loop, will hold partial dot-products in registers
+    dot_acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
 
-    #
-    for
-
-        c_mask = cids < C
+    # compute matmul tiled across SMs
+    for k0 in range(0, D, BLOCK_K):
+        k_range = k0 + tl.arange(0, BLOCK_K)
 
-        #
-
-
+        # load X slice
+        x_ptrs = x_ptr + rows[:, None] * D + k_range[None, :]
+        xk = tl.load(x_ptrs, mask=row_mask[:, None]).to(tl.float16)
 
-        #
-
+        # load C slice
+        c_ptrs = c_ptr + cols[:, None] * D + k_range[None, :]
+        ck = tl.load(c_ptrs, mask=col_mask[:, None]).to(tl.float16)
 
-        #
-
-        dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
+        # accumulate
+        dot_acc += tl.dot(xk, tl.trans(ck), out_dtype=tl.float32)
 
-
-
+    # finish distance formula
+    dist = tl.fma(dot_acc, -2.0, x_n[:, None] + c_n[None, :])  # [BM, BN]
 
-
-
-        best_idx = tl.where(improved, chunk + local_min_idx, best_idx)
+    # local arg-min (inside this tile)
+    tile_min, tile_idx = tl.min(dist, axis=1, return_indices=True)
 
-    #
-    tl.
+    # compete with global best using atomics
+    prev = tl.atomic_min(best_dist_ptr + rows, tile_min, mask=row_mask)
+    improved = tile_min < prev
 
+    # update best_ids
+    tl.store(best_idx_ptr + rows, tl.where(improved, col_start + tile_idx, tl.load(best_idx_ptr + rows)), mask=row_mask)
 
-
+
+def triton_kmeans(
     data_chunk: torch.Tensor,
     data_chunk_norms: torch.Tensor,
     centroids: torch.Tensor,
     centroids_sqnorm: torch.Tensor,
     best_ids: torch.Tensor,
 ):
-    """
-    Launches the Triton kernel to assign each point to its nearest centroid in one pass.
-
-    best_ids: pre-allocated [B] (int32) to store the best centroid ID of each point.
-    """
     B, D = data_chunk.shape
     C = centroids.shape[0]
+    best_dist = torch.full((B,), 1e38, device=data_chunk.device, dtype=torch.float32)
 
     def grid(meta):
-        return (triton.cdiv(B, meta["BLOCK_M"]),)
-
-
-
-
-
-
-
-
-
-
-
+        return (triton.cdiv(B, meta["BLOCK_M"]) * triton.cdiv(C, meta["BLOCK_N"]),)  # 1D grid
+
+    # Without this Triton always tries to launch from device:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with device_guard(data_chunk):
+        _kmeans_kernel[grid](
+            data_chunk,
+            data_chunk_norms,
+            centroids,
+            centroids_sqnorm,
+            best_dist,
+            best_ids,
+            B,
+            C,
+            D,
+        )
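The kernel computes squared Euclidean distances as ||x||^2 + ||c||^2 - 2 * x @ c.T (the `tl.fma` over tiled fp16 dot products) and resolves the global argmin across centroid tiles with `tl.atomic_min`. A hedged sanity-check sketch against a plain PyTorch reference, assuming a CUDA device with Triton installed; shapes are deliberately multiples of the block sizes, and fp16 accumulation may flip a few near-ties:

```python
import torch

from fastkmeans.triton_kernels import triton_kmeans

x = torch.randn(4_096, 128, device="cuda")
c = torch.randn(256, 128, device="cuda")
x_n, c_n = (x**2).sum(dim=1), (c**2).sum(dim=1)

best_ids = torch.zeros(4_096, device="cuda", dtype=torch.int64)
triton_kmeans(x, x_n, c, c_n, best_ids)  # fills best_ids in place

# Reference: dist = ||x||^2 + ||c||^2 - 2 * x @ c.T, argmin over centroids
ref = torch.addmm(x_n[:, None] + c_n[None, :], x, c.T, alpha=-2.0).argmin(dim=1)
assert (best_ids == ref).float().mean() > 0.99  # allow rare fp16 tie flips
```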
{fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fastkmeans
-Version: 0.3.0
+Version: 0.4.0
 Summary: Add your description here
 Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
 Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
@@ -22,7 +22,7 @@ _A fast and efficient k-means implementation for PyTorch, with support for GPU a
 
 ---
 
-Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted-in anywhere you need "fast-enough"
+Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted-in anywhere you need "fast-enough" PyTorch native k-means clustering. It's compatible with any PyTorch-compatible CPU or GPU, matching or outperforming `faiss` by ~4-5× on a single GPU, and is without install woes, relying on just two dependencies you already have installed: `torch` and `numpy`.
 
 ### Get started
 
@@ -43,18 +43,22 @@ There's very, very little to this library. It provides a single interface, `Fast
 
 #### Behaviour
 
-Whenever possible, the library
+Whenever possible, the library attempts to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/main/fastkmeans/kmeans.py#L170) to see what the arguments are, as they are straightforward.
+
+The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early-stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids don't move more than `tol` between iterations. This is unlike faiss, whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
 
 #### Chunking
 
 The algorithm is implemented with a double-chunking logic, where both the data points and the centroids are split into moderately-sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing chunk sizes will speed up computations, at the cost of memory usage, and decreasing them will have the reverse effect.
 
+> Note: See the Triton section below for Triton-specific chunking details.
+
 ### Why `fastkmeans`?
 
 The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, in both its Stanford implementation, its PyLate implementation, and for the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great on their own merits, have flaws for this particular use:
 
 - `faiss` is highly-optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install as there are no "perfect" PyPI wheels, with many segfault issues reported, as the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not installed via conda too. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
-- `fast-pytorch-kmeans` is a great library which provides lightning fast kmeans implementation in PyTorch
+- `fast-pytorch-kmeans` is a great library which provides a lightning-fast k-means implementation in PyTorch. However, it relies on highly vectorized operations which are exceedingly memory hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
 - `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
 
 There are some libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
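To make the behaviour and chunking knobs above concrete, a small sketch; the parameter names come from the constructor diff in `fastkmeans/kmeans.py`, while the `d`/`k` arguments are assumed to follow the faiss-style signature:

```python
from fastkmeans import FastKMeans

km = FastKMeans(
    d=128,
    k=262_144,
    tol=-1,                       # faiss-style: run every iteration, no early stop
    chunk_size_data=51_200,       # default; raise to trade memory for speed
    chunk_size_centroids=10_240,  # default
    max_points_per_centroid=256,  # cap subsampling at 256 points per centroid
)
```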
@@ -65,8 +69,30 @@ There are some libraries (such as NVIDIA's own implementations), but they again
 - The "chunking" defaults to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
 - The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up; this might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
 
+### Triton Kernel
+
+`fastkmeans`'s Triton kmeans kernel is ~4-5 times faster than single-GPU `faiss` or `fastkmeans`'s PyTorch backend. On a modern GPU (Ampere or newer), the Triton backend is enabled by default.
+
+While the Triton kernel uses significantly less memory than the PyTorch implementation, increasing the chunk size above 512K can result in slower performance.
+
 ### Speed
 
-Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536,
+Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, and 131072 clusters, each w/ cluster_size*100 data points).
 
 
+
+#### Benchmarking
+
+To benchmark `fastkmeans` against `faiss` on your own machine, install faiss and PyTorch 2.5 via the `bench_env.yaml` Conda environment:
+
+```bash
+conda env create -f bench_env.yaml
+conda activate fastkmeans
+pip install fastkmeans
+```
+
+Then, run the benchmark script:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python speedbench.py --do-faiss --do-fastkmeans --do-fastkmeans-triton --do-evals
+```
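As the `FastKMeans.__init__` diff above shows, `use_triton=None` auto-selects the Triton backend when Triton imports and the GPU supports bfloat16, and forcing it without Triton raises a `ValueError`. A short sketch of the three modes (again assuming the faiss-style `d`/`k` arguments):

```python
from fastkmeans import FastKMeans

km_auto = FastKMeans(d=128, k=8_192)                     # use_triton=None: auto-detect
km_triton = FastKMeans(d=128, k=8_192, use_triton=True)  # force Triton; raises if absent
km_torch = FastKMeans(d=128, k=8_192, use_triton=False)  # chunked PyTorch backend
```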
fastkmeans-0.4.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+fastkmeans/__init__.py,sha256=NcMnvnLDRqvTrAXfvVJB0M845Fe8imRzI90dAOlc3MY,79
+fastkmeans/kmeans.py,sha256=l8W79LZszDiDSS8H2jFke8Wp_blFQdslCkr4Oo2CKik,14099
+fastkmeans/triton_kernels.py,sha256=vwKUApwKsGWAhN9gMh9rWzandWoEt6rV187xLAW2t3U,4119
+fastkmeans-0.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+fastkmeans-0.4.0.dist-info/METADATA,sha256=wKNRR_pxOUxTM4cAH2jCLnt7YgUZYYAc1kRG-4YZ8eE,7552
+fastkmeans-0.4.0.dist-info/WHEEL,sha256=GHB6lJx2juba1wDgXDNlMTyM13ckjBMKf-OnwgKOCtA,91
+fastkmeans-0.4.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
+fastkmeans-0.4.0.dist-info/RECORD,,
fastkmeans-0.3.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-fastkmeans/__init__.py,sha256=EKvVznV3rTwe5Yz_tqlF_bJMZsh-5vLSjVvMtCx3Xhk,79
-fastkmeans/kmeans.py,sha256=_FbU_avJPxIEfikhPMPgBibj9ckKej21iWJCm-YNEUM,13778
-fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
-fastkmeans-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-fastkmeans-0.3.0.dist-info/METADATA,sha256=KOyw3KsNPCdF_6spJl2SByB8mVtZ3I6GE1H8N2KMUpM,6791
-fastkmeans-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
-fastkmeans-0.3.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
-fastkmeans-0.3.0.dist-info/RECORD,,
{fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/licenses/LICENSE
File without changes

{fastkmeans-0.3.0.dist-info → fastkmeans-0.4.0.dist-info}/top_level.txt
File without changes