fastkmeans 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastkmeans/__init__.py +1 -1
- fastkmeans/kmeans.py +113 -75
- fastkmeans/triton_kernels.py +105 -0
- fastkmeans-0.3.0.dist-info/METADATA +72 -0
- fastkmeans-0.3.0.dist-info/RECORD +8 -0
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.3.0.dist-info}/WHEEL +1 -1
- fastkmeans-0.1.0.dist-info/METADATA +0 -13
- fastkmeans-0.1.0.dist-info/RECORD +0 -7
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.3.0.dist-info}/top_level.txt +0 -0
fastkmeans/__init__.py
CHANGED
fastkmeans/kmeans.py
CHANGED
@@ -3,25 +3,43 @@ import time
 import torch
 import numpy as np
 
-def _get_device(preset: str = None):
-    if preset:
-
-    if
-
+def _get_device(preset: str | int | torch.device | None = None):
+    if isinstance(preset, torch.device):
+        return preset
+    if isinstance(preset, str):
+        return torch.device(preset)
+    if torch.cuda.is_available():  # cuda currently handles both AMD and NVIDIA GPUs
+        return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        return torch.device('mps')
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return torch.device(f'xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}')
+    return torch.device('cpu')
+
+
+def _is_bfloat16_supported(device: torch.device):
+    if device.type == 'cuda':
+        return torch.cuda.is_bf16_supported()
+    elif device.type == 'xpu' and hasattr(torch.xpu, 'is_bf16_supported'):
+        return torch.xpu.is_bf16_supported()
+    else:
+        return False
 
 
 @torch.inference_mode()
 def _kmeans_torch_double_chunked(
     data: torch.Tensor,
     data_norms: torch.Tensor,
     k: int,
+    device: torch.device,
+    dtype: torch.dtype | None = None,
     max_iters: int = 25,
     tol: float = 1e-8,
-    device: str = None,
-    dtype: torch.dtype = None,
     chunk_size_data: int = 50_000,
     chunk_size_centroids: int = 10_000,
     max_points_per_centroid: int = 256,
     verbose: bool = False,
+    use_triton: bool | None = None,
 ):
     """
     An efficient kmeans implementation that minimises OOM risks on modern hardware by using conservative double chunking.
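To make the new behaviour concrete, here is a sketch of how `_get_device` now resolves presets on a hypothetical machine with a single CUDA GPU (outputs differ by hardware):

```python
import torch

from fastkmeans.kmeans import _get_device

# Assumed: one CUDA GPU is available; on MPS/XPU/CPU machines the fallbacks differ.
print(_get_device("cuda:1"))             # device(type='cuda', index=1) -- strings are taken at face value
print(_get_device(torch.device("cpu")))  # device(type='cpu') -- torch.device objects pass through untouched
print(_get_device(3))                    # device(type='cuda', index=0) -- out-of-range ordinals fall back to 0
print(_get_device())                     # device(type='cuda', index=0) -- first available accelerator wins
```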
@@ -32,6 +50,13 @@ def _kmeans_torch_double_chunked(
     labels_cpu : torch.Tensor, shape (n_samples_used,), long
         Where n_samples_used can be smaller than the original if subsampling occurred.
     """
+
+    if use_triton:
+        from fastkmeans.triton_kernels import chunked_kmeans_kernel
+
+    if dtype is None:
+        dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
+
     n_samples_original, n_features = data.shape
     n_samples = n_samples_original
 
@@ -47,19 +72,16 @@ def _kmeans_torch_double_chunked(
     if n_samples < k:
         raise ValueError(f"Number of training points ({n_samples}) is less than k ({k}).")
 
-    if dtype is None: dtype = torch.float16 if device == 'cuda' else torch.float32
-
     # centroid init -- random is the only supported init
     rand_indices = torch.randperm(n_samples)[:k]
     centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
     prev_centroids = centroids.clone()
 
-    labels = torch.empty(n_samples, dtype=torch.
-
+    labels = torch.empty(n_samples, dtype=torch.int64, device='cpu')  # Keep labels on CPU
 
     for iteration in range(max_iters):
         iteration_start_time = time.time()
-
+
         centroid_norms = (centroids ** 2).sum(dim=1)
         cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
         cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
@@ -71,58 +93,62 @@ def _kmeans_torch_double_chunked(
             data_chunk = data[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
             data_chunk_norms = data_norms[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
             batch_size = data_chunk.size(0)
-
-
-
-
-
-
-
-
-
-
-                dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
-                dist_chunk = dist_chunk.addmm_(
-                    data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
+            best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
+
+            if use_triton:
+                chunked_kmeans_kernel(
+                    data_chunk=data_chunk,
+                    data_chunk_norms=data_chunk_norms,
+                    centroids=centroids,
+                    centroids_sqnorm=centroid_norms,
+                    best_ids=best_ids,
                 )
+            else:
+                best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
+                c_start = 0
+                while c_start < k:
+                    c_end = min(c_start + chunk_size_centroids, k)
+                    centroid_chunk = centroids[c_start:c_end]
+                    centroid_chunk_norms = centroid_norms[c_start:c_end]
 
-
-
-                best_dist[improved_mask] = local_min_vals[improved_mask]
-                best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                    dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
+                    dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
 
-
+                    local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
+                    improved_mask = local_min_vals < best_dist
+                    best_dist[improved_mask] = local_min_vals[improved_mask]
+                    best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+
+                    c_start = c_end
 
             cluster_sums.index_add_(0, best_ids, data_chunk.float())
-            cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, dtype=torch.float32))
+            cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
 
             labels[start_idx:end_idx] = best_ids.to('cpu', non_blocking=True)
             start_idx = end_idx
 
-        new_centroids = torch.zeros_like(centroids, device=device, dtype=
+        new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
         non_empty = (cluster_counts > 0)
-        new_centroids[non_empty] = (
-            cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)
-        )
+        new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
 
         empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
         if len(empty_ids) > 0:
             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device='cpu')
-            random_data = data[reinit_indices].to(device=device, dtype=
+            random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
             new_centroids[empty_ids] = random_data
 
-        new_centroids = new_centroids.to(dtype=dtype)
-
         shift = torch.norm(new_centroids - prev_centroids.to(new_centroids.device), dim=1).sum().item()
         centroids = new_centroids
 
         prev_centroids = centroids.clone()
-
+
         iteration_time = time.time() - iteration_start_time
-        if verbose:
-
+        if verbose:
+            print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
+
         if shift < tol:
-            if verbose:
+            if verbose:
+                print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
             break
 
     centroids_cpu = centroids.to('cpu', dtype=torch.float32)
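Both assignment paths above compute squared Euclidean distances through the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2<x, c>, with the cross term fused into a single `addmm_`. A standalone sanity check of that trick (shapes are illustrative, not from the package):

```python
import torch

x = torch.randn(512, 128)        # a "data chunk"
c = torch.randn(64, 128)         # a "centroid chunk"
x_norms = (x ** 2).sum(dim=1)    # ||x||^2, shape [512]
c_norms = (c ** 2).sum(dim=1)    # ||c||^2, shape [64]

# dist[i, j] = ||x_i||^2 + ||c_j||^2 - 2 * <x_i, c_j>, fused into one in-place addmm_
dist = x_norms.unsqueeze(1) + c_norms.unsqueeze(0)
dist = dist.addmm_(x, c.t(), alpha=-2.0, beta=1.0)

ref = torch.cdist(x, c) ** 2     # the naive route, for comparison
assert torch.allclose(dist, ref, atol=1e-2)  # identical up to fp accumulation error
```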
@@ -155,6 +181,9 @@ class FastKMeans:
         Chunk size along the data dimension for assignment/update steps.
     chunk_size_centroids : int, default=10_000
         Chunk size along the centroid dimension for assignment/update steps.
+    use_triton : bool | None, default=None
+        Use the fast Triton backend for the assignment/update steps.
+        If None, the Triton backend will be enabled for modern GPUs.
     """
 
     def __init__(
@@ -168,33 +197,36 @@ class FastKMeans:
         max_points_per_centroid: int = 256,
         chunk_size_data: int = 50_000,
         chunk_size_centroids: int = 10_000,
-        device: str = None,
+        device: str | int | torch.device | None = None,
         dtype: torch.dtype = None,
         pin_gpu_memory: bool = True,
         verbose: bool = False,
         nredo: int = 1,  # for compatibility only
+        use_triton: bool | None = None,
     ):
         self.d = d
         self.k = k
         self.niter = niter
         self.tol = tol
-        self.gpu = gpu
         self.seed = seed
         self.max_points_per_centroid = max_points_per_centroid
         self.chunk_size_data = chunk_size_data
         self.chunk_size_centroids = chunk_size_centroids
+        self.device = _get_device("cpu" if gpu is False else device)
         self.centroids = None
-        if device not in [None, 'cuda'] and self.gpu: print("Warning: device is set to 'cuda' but gpu is True, ignoring 'device' argument and setting it to 'cuda'!")
-        self.device = 'cuda' if self.gpu else device
         self.dtype = dtype
         self.pin_gpu_memory = pin_gpu_memory
-        if nredo != 1: raise ValueError("nredo must be 1, redos not currently supported")
         self.verbose = verbose
+        if use_triton is not False:
+            use_triton = _is_bfloat16_supported(self.device)  # assume triton is supported if GPU supports bfloat16
+        self.use_triton = use_triton
+        if nredo != 1:
+            raise ValueError("nredo must be 1, redos not currently supported")
 
     def train(self, data: np.ndarray):
         """
         Trains (fits) the KMeans model on the given data and sets `self.centroids`. Designed to mimic faiss's `train()` method.
-
+
         Parameters
         ----------
         data : np.ndarray of shape (n_samples, d), float32
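Taken together, the constructor changes mean device selection and backend choice are now plain keyword arguments. A usage sketch (values are illustrative):

```python
from fastkmeans import FastKMeans

km_auto = FastKMeans(d=128, k=4096)                     # picks CUDA/MPS/XPU if present, else CPU
km_cpu = FastKMeans(d=128, k=4096, gpu=False)           # gpu=False forces CPU regardless of `device`
km_dev1 = FastKMeans(d=128, k=4096, device=1)           # integer ordinals select a specific GPU
km_torch = FastKMeans(d=128, k=4096, use_triton=False)  # opt out of the Triton backend
```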
@@ -224,6 +256,7 @@ class FastKMeans:
             chunk_size_centroids=self.chunk_size_centroids,
             max_points_per_centroid=self.max_points_per_centroid,
             verbose=self.verbose,
+            use_triton=self.use_triton,
         )
         self.centroids = centroids.numpy()
 
@@ -242,6 +275,8 @@ class FastKMeans:
         -------
         labels : np.ndarray of shape (n_samples,), int64
         """
+        if self.use_triton:
+            from fastkmeans.triton_kernels import chunked_kmeans_kernel
         if self.centroids is None:
             raise RuntimeError("Must call train() or fit() before predict().")
 
@@ -250,11 +285,7 @@ class FastKMeans:
 
         # We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
         centroids_torch = torch.from_numpy(self.centroids)
-
-        if device == 'cpu' and self.gpu and torch.cuda.is_available():
-            device = 'cuda'  # If user asked for GPU, put centroids there
-
-        centroids_torch = centroids_torch.to(device=device, dtype=torch.float32)
+        centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
         centroid_norms = (centroids_torch ** 2).sum(dim=1)
 
         n_samples = data_torch.shape[0]
@@ -263,30 +294,37 @@ class FastKMeans:
         start_idx = 0
         while start_idx < n_samples:
             end_idx = min(start_idx + self.chunk_size_data, n_samples)
-            data_chunk = data_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
-            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
-            batch_size = data_chunk.size(0)
 
-
-
-
-
-
-
-
-
-
-
-
-
-                    data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
+            data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+            batch_size = data_chunk.size(0)
+            best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
+
+            if self.use_triton:
+                chunked_kmeans_kernel(
+                    data_chunk,
+                    data_chunk_norms,
+                    centroids_torch,
+                    centroid_norms,
+                    best_ids,
                 )
-
-
-
-
-
-
+            else:
+                best_dist = torch.full((batch_size,), float('inf'), device=self.device, dtype=torch.float32)
+                c_start = 0
+                k = centroids_torch.shape[0]
+                while c_start < k:
+                    c_end = min(c_start + self.chunk_size_centroids, k)
+                    centroid_chunk = centroids_torch[c_start:c_end]
+                    centroid_chunk_norms = centroid_norms[c_start:c_end]
+
+                    dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
+                    dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
+
+                    local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
+                    improved_mask = local_min_vals < best_dist
+                    best_dist[improved_mask] = local_min_vals[improved_mask]
+                    best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                    c_start = c_end
 
             labels[start_idx:end_idx] = best_ids.to('cpu')
             start_idx = end_idx
fastkmeans/triton_kernels.py
ADDED
@@ -0,0 +1,105 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.heuristics(
+    {
+        "BLOCK_M": lambda a: 128 if a["D"] <= 128 else (64 if a["D"] <= 512 else 32),
+        "BLOCK_N": lambda a: 64 if a["D"] <= 128 else 32,
+        "num_warps": lambda a: 4 if a["D"] <= 64 else 8,
+        "num_stages": lambda a: 2 if a["D"] < 64 else 1,
+    }
+)
+@triton.jit
+def _chunked_kmeans_kernel(
+    data_ptr,              # [B, D], row-major
+    x_norm_ptr,            # [B], precomputed L2 norms of data
+    centroids_ptr,         # [C, D], row-major
+    centroids_sqnorm_ptr,  # [C], precomputed L2 norms of centroids
+    best_ids_ptr,          # [B], int32 (to store best centroid indices)
+    B,                     # number of data points
+    C,                     # number of centroids
+    D: tl.constexpr,       # dimension, or number of features
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """
+    Each Triton block processes BLOCK_M rows of data. The kernel:
+      1) loads those rows (and their precomputed norms from x_norm_ptr),
+      2) loops over all centroids in chunks of BLOCK_N,
+      3) computes distances, finds the best centroid,
+      4) writes out the best centroid index for each data point.
+    """
+    # 1) Identify which data rows this block handles
+    block_id = tl.program_id(axis=0)
+    row_start = block_id * BLOCK_M
+    rows = row_start + tl.arange(0, BLOCK_M)
+    mask = rows < B
+
+    # 2) Load data rows and precomputed x_norm: shape [BLOCK_M, D]
+    row_offsets = rows[:, None] * D + tl.arange(0, D)
+    x = tl.load(data_ptr + row_offsets, mask=mask[:, None], other=0.0)
+
+    # shape: [BLOCK_M]
+    x_norm = tl.load(x_norm_ptr + rows, mask=mask, other=0.0)
+
+    # Prepare "best distance" + "best index"
+    best_dist = tl.full([BLOCK_M], 1e38, dtype=tl.float32)
+    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
+
+    # 3) Iterate over the centroids in chunks of BLOCK_N
+    for chunk in range(0, C, BLOCK_N):
+        cids = chunk + tl.arange(0, BLOCK_N)
+        c_mask = cids < C
+
+        # Load sub-block of centroids: shape [BLOCK_N, D]
+        c_offsets = cids[:, None] * D + tl.arange(0, D)
+        cvals = tl.load(centroids_ptr + c_offsets, mask=c_mask[:, None], other=0.0).to(x.dtype)
+
+        # Load centroid norms: shape [BLOCK_N]
+        c_sqnorm = tl.load(centroids_sqnorm_ptr + cids, mask=c_mask, other=0.0).to(x.dtype)
+
+        # Compute distance = x_norm + c_sqnorm - 2 * dot(x, c)
+        dots = tl.dot(x, tl.trans(cvals))  # shape [BLOCK_M, BLOCK_N]
+        dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
+
+        # Find the argmin along the BLOCK_N dimension
+        local_min_vals, local_min_idx = tl.min(dist_chunk, axis=1, return_indices=True)
+
+        improved = local_min_vals < best_dist
+        best_dist = tl.where(improved, local_min_vals, best_dist)
+        best_idx = tl.where(improved, chunk + local_min_idx, best_idx)
+
+    # 4) Write out the best centroid indices
+    tl.store(best_ids_ptr + rows, best_idx, mask=mask)
+
+
+def chunked_kmeans_kernel(
+    data_chunk: torch.Tensor,
+    data_chunk_norms: torch.Tensor,
+    centroids: torch.Tensor,
+    centroids_sqnorm: torch.Tensor,
+    best_ids: torch.Tensor,
+):
+    """
+    Launches the Triton kernel to assign each point to its nearest centroid in one pass.
+
+    best_ids: pre-allocated [B] (int32) to store the best centroid ID of each point.
+    """
+    B, D = data_chunk.shape
+    C = centroids.shape[0]
+
+    def grid(meta):
+        return (triton.cdiv(B, meta["BLOCK_M"]),)
+
+    _chunked_kmeans_kernel[grid](
+        data_chunk,
+        data_chunk_norms,
+        centroids,
+        centroids_sqnorm,
+        best_ids,
+        B,  # num_points
+        C,  # num_centroids
+        D,  # dimension, or number of features
+    )
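A sketch of driving the wrapper directly, mirroring how `kmeans.py` calls it; this assumes a CUDA machine with `triton` installed, and allocates `best_ids` as int64 the way `kmeans.py` does (the docstring above says int32):

```python
import torch
from fastkmeans.triton_kernels import chunked_kmeans_kernel

device = torch.device("cuda")
x = torch.randn(4096, 128, device=device)  # data chunk; D=128 suits tl.arange (power of 2)
c = torch.randn(256, 128, device=device)   # centroids
best_ids = torch.zeros(4096, device=device, dtype=torch.int64)  # pre-allocated output

chunked_kmeans_kernel(x, (x ** 2).sum(dim=1), c, (c ** 2).sum(dim=1), best_ids)

# Cross-check against a plain PyTorch assignment pass
ref = torch.cdist(x, c).argmin(dim=1)
print((best_ids == ref).float().mean())  # expect ~1.0, modulo floating-point ties
```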
fastkmeans-0.3.0.dist-info/METADATA
ADDED
@@ -0,0 +1,72 @@
+Metadata-Version: 2.4
+Name: fastkmeans
+Version: 0.3.0
+Summary: Add your description here
+Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
+Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
+License: Apache-2.0
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch
+Requires-Dist: numpy
+Dynamic: license-file
+
+# fastkmeans
+
+
+[](https://twitter.com/bclavie)
+<!-- [](https://pepy.tech/project/fastkmeans) -->
+
+_A fast and efficient k-means implementation for PyTorch, with support for GPU and CPU._
+
+---
+
+Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with both CPU and GPU, matching or outspeeding `faiss` (except in multi-GPU settings), and it comes without any install woes, relying on just two dependencies you already have installed anyway: `torch` and `numpy`.
+
+### Get started
+
+```sh
+[uv] pip install fastkmeans
+```
+
+... and that's all you need to do! `FastKMeans` is now ready to use.
+
+### So what does this do?
+
+There's very, very little to this library. It provides a single interface, `FastKMeans`, which you can use by importing it from `fastkmeans`. This interface has been designed to be a slot-in replacement for existing FAISS implementations, while being *mostly* sklearn-compliant as well. Effectively, this means that four methods are exposed:
+
+- train(): mimics the FAISS API, training the model on a dataset.
+- fit(): mimics the sklearn API, training the model on a dataset.
+- predict(): mimics the sklearn API, using the trained clusters to predict where new points belong.
+- fit_predict(): mimics the sklearn API, chaining the two calls above.
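A minimal end-to-end sketch of that API (shapes and values are illustrative):

```python
import numpy as np
from fastkmeans import FastKMeans

data = np.random.randn(10_000, 128).astype(np.float32)

km = FastKMeans(d=128, k=64)
km.train(data)             # faiss-style: fits and sets km.centroids, an np.ndarray of shape (64, 128)
labels = km.predict(data)  # sklearn-style: int64 labels of shape (10_000,)

labels = FastKMeans(d=128, k=64).fit_predict(data)  # or the one-shot sklearn route
```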
+
+#### Behaviour
+
+Whenever possible, the library will attempt to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see what the arguments are, as they're very straightforward. The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early-stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids move less than `tol` between iterations. This is unlike faiss, whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
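Restoring the faiss-style fixed-iteration behaviour is then a single argument, as in this sketch:

```python
from fastkmeans import FastKMeans

# tol=-1 disables early stopping, so all `niter` iterations always run
km = FastKMeans(d=128, k=4096, niter=25, tol=-1)
```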
+
+#### Chunking
+
+The algorithm is implemented with a double-chunking logic, where both the data points and the centroids are split into moderately sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing chunk sizes will speed up computations at the cost of memory usage, and decreasing them will have the reverse effect.
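The knobs in question are the constructor's `chunk_size_data` and `chunk_size_centroids`; a sketch of trading speed for a smaller memory footprint (values are illustrative):

```python
from fastkmeans import FastKMeans

km = FastKMeans(
    d=128,
    k=262_144,
    chunk_size_data=25_000,      # default 50_000: data rows scored per pass
    chunk_size_centroids=5_000,  # default 10_000: centroids scored per pass
)
```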
+
+### Why `fastkmeans`?
+
+The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, in its Stanford implementation, its PyLate implementation, and the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great on their own merits, have flaws for this particular use:
+
+- `faiss` is highly optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install: there are no "perfect" PyPI wheels, many segfault issues have been reported, and the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not installed via conda too. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
+- `fast-pytorch-kmeans` is a great library which provides a lightning-fast kmeans implementation in PyTorch: in fact, it's faster than this library! However, it relies on highly vectorized operations which are exceedingly memory-hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
+- `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
+
+There are other libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
+
+### Limitations
+
+- On a few toy datasets & MNIST, `fastkmeans` reaches roughly the same NMI as `faiss` and `scikit-learn`, indicating that it creates at least somewhat coherent clusters. However, it's not extensively tested, especially in non-ColBERT uses, so your mileage may vary.
+- The "chunking" defaults used to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
+- The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up; this might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
+
+### Speed
+
+Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales commonly used in ColBERT-style indexing (8192, 16384, 32768, 65536, 131072 and 262144 clusters, each with cluster_size*100 data points).
+
+
fastkmeans-0.3.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+fastkmeans/__init__.py,sha256=EKvVznV3rTwe5Yz_tqlF_bJMZsh-5vLSjVvMtCx3Xhk,79
+fastkmeans/kmeans.py,sha256=_FbU_avJPxIEfikhPMPgBibj9ckKej21iWJCm-YNEUM,13778
+fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
+fastkmeans-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+fastkmeans-0.3.0.dist-info/METADATA,sha256=KOyw3KsNPCdF_6spJl2SByB8mVtZ3I6GE1H8N2KMUpM,6791
+fastkmeans-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+fastkmeans-0.3.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
+fastkmeans-0.3.0.dist-info/RECORD,,
fastkmeans-0.1.0.dist-info/METADATA
DELETED
@@ -1,13 +0,0 @@
-Metadata-Version: 2.4
-Name: fastkmeans
-Version: 0.1.0
-Summary: Add your description here
-Author-email: Ben Clavié <bc@answer.ai>
-Maintainer-email: Ben Clavié <bc@answer.ai>
-License: Apache-2.0
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: torch
-Requires-Dist: numpy
-Dynamic: license-file
fastkmeans-0.1.0.dist-info/RECORD
DELETED
@@ -1,7 +0,0 @@
-fastkmeans/__init__.py,sha256=eIDvwiAarlwyo4n6AvHsV3cPqkbLuILUE1gFAjpzItY,79
-fastkmeans/kmeans.py,sha256=HO44oEOyI8esRDpUJYIq6tHXXVgysSRazoMSIJ2HkX8,12016
-fastkmeans-0.1.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-fastkmeans-0.1.0.dist-info/METADATA,sha256=Fy5bypSuzqwdPs-Qw2fTJ8RZn2b9AKj9POyvv4U7gkk,344
-fastkmeans-0.1.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-fastkmeans-0.1.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
-fastkmeans-0.1.0.dist-info/RECORD,,
{fastkmeans-0.1.0.dist-info → fastkmeans-0.3.0.dist-info}/licenses/LICENSE
File without changes
{fastkmeans-0.1.0.dist-info → fastkmeans-0.3.0.dist-info}/top_level.txt
File without changes