fastkmeans 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastkmeans/__init__.py +1 -1
- fastkmeans/kmeans.py +110 -75
- fastkmeans/triton_kernels.py +105 -0
- fastkmeans-0.2.0.dist-info/METADATA +72 -0
- fastkmeans-0.2.0.dist-info/RECORD +8 -0
- fastkmeans-0.1.0.dist-info/METADATA +0 -13
- fastkmeans-0.1.0.dist-info/RECORD +0 -7
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.2.0.dist-info}/WHEEL +0 -0
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {fastkmeans-0.1.0.dist-info → fastkmeans-0.2.0.dist-info}/top_level.txt +0 -0
fastkmeans/__init__.py
CHANGED
fastkmeans/kmeans.py
CHANGED
@@ -3,25 +3,45 @@ import time
|
|
3
3
|
import torch
|
4
4
|
import numpy as np
|
5
5
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
if
|
10
|
-
|
6
|
+
from fastkmeans.triton_kernels import chunked_kmeans_kernel
|
7
|
+
|
8
|
+
def _get_device(preset: str | int | torch.device | None = None):
|
9
|
+
if isinstance(preset, torch.device):
|
10
|
+
return preset
|
11
|
+
if isinstance(preset, str):
|
12
|
+
return torch.device(preset)
|
13
|
+
if torch.cuda.is_available(): # cuda currently handles both AMD and NVIDIA GPUs
|
14
|
+
return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
|
15
|
+
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
16
|
+
return torch.device('mps')
|
17
|
+
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
18
|
+
return torch.device(f'xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}')
|
19
|
+
return torch.device('cpu')
|
20
|
+
|
21
|
+
|
22
|
+
def _is_bfloat16_supported(device:torch.device):
|
23
|
+
if device.type == 'cuda':
|
24
|
+
return torch.cuda.is_bf16_supported()
|
25
|
+
elif device.type == 'xpu' and hasattr(torch.xpu, 'is_bf16_supported'):
|
26
|
+
return torch.xpu.is_bf16_supported()
|
27
|
+
else:
|
28
|
+
return False
|
29
|
+
|
11
30
|
|
12
31
|
@torch.inference_mode()
|
13
32
|
def _kmeans_torch_double_chunked(
|
14
33
|
data: torch.Tensor,
|
15
34
|
data_norms: torch.Tensor,
|
16
35
|
k: int,
|
36
|
+
device: torch.device,
|
37
|
+
dtype: torch.dtype | None = None,
|
17
38
|
max_iters: int = 25,
|
18
39
|
tol: float = 1e-8,
|
19
|
-
device: str = None,
|
20
|
-
dtype: torch.dtype = None,
|
21
40
|
chunk_size_data: int = 50_000,
|
22
41
|
chunk_size_centroids: int = 10_000,
|
23
42
|
max_points_per_centroid: int = 256,
|
24
43
|
verbose: bool = False,
|
44
|
+
use_triton: bool | None = None,
|
25
45
|
):
|
26
46
|
"""
|
27
47
|
An efficient kmeans implementation that minimises OOM risks on modern hardware by using conversative double chunking.
|
@@ -32,6 +52,10 @@ def _kmeans_torch_double_chunked(
|
|
32
52
|
labels_cpu : torch.Tensor, shape (n_samples_used,), long
|
33
53
|
Where n_samples_used can be smaller than the original if subsampling occurred.
|
34
54
|
"""
|
55
|
+
|
56
|
+
if dtype is None:
|
57
|
+
dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
|
58
|
+
|
35
59
|
n_samples_original, n_features = data.shape
|
36
60
|
n_samples = n_samples_original
|
37
61
|
|
@@ -47,19 +71,16 @@ def _kmeans_torch_double_chunked(
|
|
47
71
|
if n_samples < k:
|
48
72
|
raise ValueError(f"Number of training points ({n_samples}) is less than k ({k}).")
|
49
73
|
|
50
|
-
if dtype is None: dtype = torch.float16 if device == 'cuda' else torch.float32
|
51
|
-
|
52
74
|
# centroid init -- random is the only supported init
|
53
75
|
rand_indices = torch.randperm(n_samples)[:k]
|
54
76
|
centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
|
55
77
|
prev_centroids = centroids.clone()
|
56
78
|
|
57
|
-
labels = torch.empty(n_samples, dtype=torch.
|
58
|
-
|
79
|
+
labels = torch.empty(n_samples, dtype=torch.int64, device='cpu') # Keep labels on CPU
|
59
80
|
|
60
81
|
for iteration in range(max_iters):
|
61
82
|
iteration_start_time = time.time()
|
62
|
-
|
83
|
+
|
63
84
|
centroid_norms = (centroids ** 2).sum(dim=1)
|
64
85
|
cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
|
65
86
|
cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
|
@@ -71,58 +92,62 @@ def _kmeans_torch_double_chunked(
|
|
71
92
|
data_chunk = data[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
|
72
93
|
data_chunk_norms = data_norms[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
|
73
94
|
batch_size = data_chunk.size(0)
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
|
85
|
-
dist_chunk = dist_chunk.addmm_(
|
86
|
-
data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
|
95
|
+
best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
|
96
|
+
|
97
|
+
if use_triton:
|
98
|
+
chunked_kmeans_kernel(
|
99
|
+
data_chunk=data_chunk,
|
100
|
+
data_chunk_norms=data_chunk_norms,
|
101
|
+
centroids=centroids,
|
102
|
+
centroids_sqnorm=centroid_norms,
|
103
|
+
best_ids=best_ids,
|
87
104
|
)
|
105
|
+
else:
|
106
|
+
best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
|
107
|
+
c_start = 0
|
108
|
+
while c_start < k:
|
109
|
+
c_end = min(c_start + chunk_size_centroids, k)
|
110
|
+
centroid_chunk = centroids[c_start:c_end]
|
111
|
+
centroid_chunk_norms = centroid_norms[c_start:c_end]
|
112
|
+
|
113
|
+
dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
|
114
|
+
dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
|
88
115
|
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
116
|
+
local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
|
117
|
+
improved_mask = local_min_vals < best_dist
|
118
|
+
best_dist[improved_mask] = local_min_vals[improved_mask]
|
119
|
+
best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
|
93
120
|
|
94
|
-
|
121
|
+
c_start = c_end
|
95
122
|
|
96
123
|
cluster_sums.index_add_(0, best_ids, data_chunk.float())
|
97
|
-
cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, dtype=torch.float32))
|
124
|
+
cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
|
98
125
|
|
99
126
|
labels[start_idx:end_idx] = best_ids.to('cpu', non_blocking=True)
|
100
127
|
start_idx = end_idx
|
101
128
|
|
102
|
-
new_centroids = torch.zeros_like(centroids, device=device, dtype=
|
129
|
+
new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
|
103
130
|
non_empty = (cluster_counts > 0)
|
104
|
-
new_centroids[non_empty] = (
|
105
|
-
cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)
|
106
|
-
)
|
131
|
+
new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
|
107
132
|
|
108
133
|
empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
|
109
134
|
if len(empty_ids) > 0:
|
110
135
|
reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device='cpu')
|
111
|
-
random_data = data[reinit_indices].to(device=device, dtype=
|
136
|
+
random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
|
112
137
|
new_centroids[empty_ids] = random_data
|
113
138
|
|
114
|
-
new_centroids = new_centroids.to(dtype=dtype)
|
115
|
-
|
116
139
|
shift = torch.norm(new_centroids - prev_centroids.to(new_centroids.device), dim=1).sum().item()
|
117
140
|
centroids = new_centroids
|
118
141
|
|
119
142
|
prev_centroids = centroids.clone()
|
120
|
-
|
143
|
+
|
121
144
|
iteration_time = time.time() - iteration_start_time
|
122
|
-
if verbose:
|
123
|
-
|
145
|
+
if verbose:
|
146
|
+
print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
|
147
|
+
|
124
148
|
if shift < tol:
|
125
|
-
if verbose:
|
149
|
+
if verbose:
|
150
|
+
print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
|
126
151
|
break
|
127
152
|
|
128
153
|
centroids_cpu = centroids.to('cpu', dtype=torch.float32)
|
@@ -155,6 +180,9 @@ class FastKMeans:
|
|
155
180
|
Chunk size along the data dimension for assignment/update steps.
|
156
181
|
chunk_size_centroids : int, default=10_000
|
157
182
|
Chunk size along the centroid dimension for assignment/update steps.
|
183
|
+
use_triton : bool | None, default=None
|
184
|
+
Use the fast Triton backend for the assignment/update steps.
|
185
|
+
If None, the Triton backend will be enabled for modern GPUs.
|
158
186
|
"""
|
159
187
|
|
160
188
|
def __init__(
|
@@ -168,33 +196,36 @@ class FastKMeans:
|
|
168
196
|
max_points_per_centroid: int = 256,
|
169
197
|
chunk_size_data: int = 50_000,
|
170
198
|
chunk_size_centroids: int = 10_000,
|
171
|
-
device: str = None,
|
199
|
+
device: str | int | torch.device | None = None,
|
172
200
|
dtype: torch.dtype = None,
|
173
201
|
pin_gpu_memory: bool = True,
|
174
202
|
verbose: bool = False,
|
175
203
|
nredo: int = 1, # for compatibility only
|
204
|
+
use_triton: bool | None = None,
|
176
205
|
):
|
177
206
|
self.d = d
|
178
207
|
self.k = k
|
179
208
|
self.niter = niter
|
180
209
|
self.tol = tol
|
181
|
-
self.gpu = gpu
|
182
210
|
self.seed = seed
|
183
211
|
self.max_points_per_centroid = max_points_per_centroid
|
184
212
|
self.chunk_size_data = chunk_size_data
|
185
213
|
self.chunk_size_centroids = chunk_size_centroids
|
214
|
+
self.device = _get_device("cpu" if gpu is False else device)
|
186
215
|
self.centroids = None
|
187
|
-
if device not in [None, 'cuda'] and self.gpu: print("Warning: device is set to 'cuda' but gpu is True, ignoring 'device' argument and setting it to 'cuda'!")
|
188
|
-
self.device = 'cuda' if self.gpu else device
|
189
216
|
self.dtype = dtype
|
190
217
|
self.pin_gpu_memory = pin_gpu_memory
|
191
|
-
if nredo != 1: raise ValueError("nredo must be 1, redos not currently supported")
|
192
218
|
self.verbose = verbose
|
219
|
+
if use_triton is not False:
|
220
|
+
use_triton = _is_bfloat16_supported(self.device) # assume triton is supported if GPU supports bfloat16
|
221
|
+
self.use_triton = use_triton
|
222
|
+
if nredo != 1:
|
223
|
+
raise ValueError("nredo must be 1, redos not currently supported")
|
193
224
|
|
194
225
|
def train(self, data: np.ndarray):
|
195
226
|
"""
|
196
227
|
Trains (fits) the KMeans model on the given data and sets `self.centroids`. Designed to mimic faiss's `train()` method.
|
197
|
-
|
228
|
+
|
198
229
|
Parameters
|
199
230
|
----------
|
200
231
|
data : np.ndarray of shape (n_samples, d), float32
|
@@ -224,6 +255,7 @@ class FastKMeans:
|
|
224
255
|
chunk_size_centroids=self.chunk_size_centroids,
|
225
256
|
max_points_per_centroid=self.max_points_per_centroid,
|
226
257
|
verbose=self.verbose,
|
258
|
+
use_triton=self.use_triton,
|
227
259
|
)
|
228
260
|
self.centroids = centroids.numpy()
|
229
261
|
|
@@ -250,11 +282,7 @@ class FastKMeans:
|
|
250
282
|
|
251
283
|
# We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
|
252
284
|
centroids_torch = torch.from_numpy(self.centroids)
|
253
|
-
|
254
|
-
if device == 'cpu' and self.gpu and torch.cuda.is_available():
|
255
|
-
device = 'cuda' # If user asked for GPU, put centroids there
|
256
|
-
|
257
|
-
centroids_torch = centroids_torch.to(device=device, dtype=torch.float32)
|
285
|
+
centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
|
258
286
|
centroid_norms = (centroids_torch ** 2).sum(dim=1)
|
259
287
|
|
260
288
|
n_samples = data_torch.shape[0]
|
@@ -263,30 +291,37 @@ class FastKMeans:
|
|
263
291
|
start_idx = 0
|
264
292
|
while start_idx < n_samples:
|
265
293
|
end_idx = min(start_idx + self.chunk_size_data, n_samples)
|
266
|
-
data_chunk = data_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
|
267
|
-
data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
|
268
|
-
batch_size = data_chunk.size(0)
|
269
294
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
|
295
|
+
data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
|
296
|
+
data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
|
297
|
+
batch_size = data_chunk.size(0)
|
298
|
+
best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
|
299
|
+
|
300
|
+
if self.use_triton:
|
301
|
+
chunked_kmeans_kernel(
|
302
|
+
data_chunk,
|
303
|
+
data_chunk_norms,
|
304
|
+
centroids_torch,
|
305
|
+
centroid_norms,
|
306
|
+
best_ids,
|
283
307
|
)
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
308
|
+
else:
|
309
|
+
best_dist = torch.full((batch_size,), float('inf'), device=self.device, dtype=torch.float32)
|
310
|
+
c_start = 0
|
311
|
+
k = centroids_torch.shape[0]
|
312
|
+
while c_start < k:
|
313
|
+
c_end = min(c_start + self.chunk_size_centroids, k)
|
314
|
+
centroid_chunk = centroids_torch[c_start:c_end]
|
315
|
+
centroid_chunk_norms = centroid_norms[c_start:c_end]
|
316
|
+
|
317
|
+
dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
|
318
|
+
dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
|
319
|
+
|
320
|
+
local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
|
321
|
+
improved_mask = local_min_vals < best_dist
|
322
|
+
best_dist[improved_mask] = local_min_vals[improved_mask]
|
323
|
+
best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
|
324
|
+
c_start = c_end
|
290
325
|
|
291
326
|
labels[start_idx:end_idx] = best_ids.to('cpu')
|
292
327
|
start_idx = end_idx
|
@@ -0,0 +1,105 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
import triton.language as tl
|
4
|
+
|
5
|
+
|
6
|
+
@triton.heuristics(
    {
        "BLOCK_M": lambda a: 128 if a["D"] <= 128 else (64 if a["D"] <= 512 else 32),
        "BLOCK_N": lambda a: 64 if a["D"] <= 128 else 32,
        # tl.arange needs a power-of-two extent and tl.dot needs K >= 16, so the
        # feature axis is padded up to BLOCK_D and the padding is masked on load.
        "BLOCK_D": lambda a: max(16, triton.next_power_of_2(a["D"])),
        "num_warps": lambda a: 4 if a["D"] <= 64 else 8,
        "num_stages": lambda a: 2 if a["D"] < 64 else 1,
    }
)
@triton.jit
def _chunked_kmeans_kernel(
    data_ptr,  # [B, D], dense row-major data chunk
    x_norm_ptr,  # [B], per-row norms; must be squared L2 norms for the distance formula below
    centroids_ptr,  # [C, D], dense row-major centroids
    centroids_sqnorm_ptr,  # [C], squared L2 norms of the centroids
    best_ids_ptr,  # [B], output: index of the nearest centroid for each row
    B,  # number of data points
    C,  # number of centroids
    D: tl.constexpr,  # dimension, or number of features
    BLOCK_M: tl.constexpr,  # data rows handled per program instance
    BLOCK_N: tl.constexpr,  # centroids scored per inner-loop step
    BLOCK_D: tl.constexpr,  # D rounded up to a power of two (at least 16)
):
    """
    Assign each data row to its nearest centroid under squared Euclidean distance.

    Each program instance handles BLOCK_M consecutive rows:
      1) loads those rows (feature axis padded to BLOCK_D, padding read as 0) and their norms,
      2) loops over all centroids in chunks of BLOCK_N,
      3) scores ||x||^2 + ||c||^2 - 2 x.c; centroid slots past C are forced to +inf,
      4) writes the index of the best centroid for each valid row.
    """
    # 1) Identify which data rows this program handles
    block_id = tl.program_id(axis=0)
    rows = block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < B
    # 64-bit offsets so rows * D cannot overflow int32 on large inputs.
    row_base = rows.to(tl.int64) * D

    cols = tl.arange(0, BLOCK_D)
    col_mask = cols < D

    # 2) Load data rows [BLOCK_M, BLOCK_D] and their norms [BLOCK_M]
    x = tl.load(
        data_ptr + row_base[:, None] + cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )
    x_norm = tl.load(x_norm_ptr + rows, mask=row_mask, other=0.0).to(tl.float32)

    # Running best distance / index per row
    best_dist = tl.full([BLOCK_M], 1e38, dtype=tl.float32)
    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)

    # 3) Iterate over the centroids in chunks of BLOCK_N
    for chunk in range(0, C, BLOCK_N):
        cids = chunk + tl.arange(0, BLOCK_N)
        c_mask = cids < C
        c_base = cids.to(tl.int64) * D

        # Sub-block of centroids [BLOCK_N, BLOCK_D], cast to the data dtype so tl.dot sees matching inputs
        cvals = tl.load(
            centroids_ptr + c_base[:, None] + cols[None, :],
            mask=c_mask[:, None] & col_mask[None, :],
            other=0.0,
        ).to(x.dtype)
        c_sqnorm = tl.load(centroids_sqnorm_ptr + cids, mask=c_mask, other=0.0).to(tl.float32)

        # distance = x_norm + c_sqnorm - 2 * dot(x, c), accumulated in fp32
        dots = tl.dot(x, tl.trans(cvals))  # [BLOCK_M, BLOCK_N]
        dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
        # A slot past C loads an all-zero centroid with zero norm, so it would score
        # ||x||^2 (the distance to the origin). That can beat every real centroid and
        # yield an id >= C, so exclude those slots explicitly.
        dist_chunk = tl.where(c_mask[None, :], dist_chunk, float("inf"))

        # Argmin along the BLOCK_N axis, then merge into the running best
        local_min_vals, local_min_idx = tl.min(dist_chunk, axis=1, return_indices=True)
        improved = local_min_vals < best_dist
        best_dist = tl.where(improved, local_min_vals, best_dist)
        best_idx = tl.where(improved, chunk + local_min_idx, best_idx)

    # 4) Write out the best centroid indices for valid rows only
    tl.store(best_ids_ptr + rows, best_idx, mask=row_mask)
|
76
|
+
|
77
|
+
|
78
|
+
def chunked_kmeans_kernel(
    data_chunk: torch.Tensor,
    data_chunk_norms: torch.Tensor,
    centroids: torch.Tensor,
    centroids_sqnorm: torch.Tensor,
    best_ids: torch.Tensor,
):
    """
    Launch the Triton kernel that assigns each point to its nearest centroid in one pass.

    Parameters
    ----------
    data_chunk : torch.Tensor, shape (B, D)
        Data rows to assign.
    data_chunk_norms : torch.Tensor, shape (B,)
        Per-row norms of ``data_chunk``. The kernel's distance formula assumes
        these are squared L2 norms.
    centroids : torch.Tensor, shape (C, D)
        Current centroids.
    centroids_sqnorm : torch.Tensor, shape (C,)
        Squared L2 norms of ``centroids``.
    best_ids : torch.Tensor, shape (B,)
        Pre-allocated, contiguous output filled in place with the nearest-centroid
        index of each row. The callers in this package allocate it as int64.

    Raises
    ------
    ValueError
        If the shapes are inconsistent, there are no centroids for a non-empty
        chunk, or ``best_ids`` is not contiguous (the kernel writes through its
        raw pointer, so a copy would silently drop the results).
    """
    B, D = data_chunk.shape
    if centroids.dim() != 2 or centroids.shape[1] != D:
        raise ValueError(
            f"centroids must have shape (C, {D}), got {tuple(centroids.shape)}"
        )
    C = centroids.shape[0]
    if data_chunk_norms.numel() != B or best_ids.numel() != B:
        raise ValueError("data_chunk_norms and best_ids must both have one entry per data row")
    if centroids_sqnorm.numel() != C:
        raise ValueError("centroids_sqnorm must have one entry per centroid")
    if not best_ids.is_contiguous():
        raise ValueError("best_ids must be contiguous")
    if B == 0:
        return  # nothing to assign; avoid launching an empty grid
    if C == 0:
        raise ValueError("at least one centroid is required to assign points")

    # The kernel indexes raw memory as dense row-major arrays.
    data_chunk = data_chunk.contiguous()
    data_chunk_norms = data_chunk_norms.contiguous()
    centroids = centroids.contiguous()
    centroids_sqnorm = centroids_sqnorm.contiguous()

    def grid(meta):
        return (triton.cdiv(B, meta["BLOCK_M"]),)

    _chunked_kmeans_kernel[grid](
        data_chunk,
        data_chunk_norms,
        centroids,
        centroids_sqnorm,
        best_ids,
        B,  # num_points
        C,  # num_centroids
        D,  # dimension, or number of features
    )
|
@@ -0,0 +1,72 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: fastkmeans
|
3
|
+
Version: 0.2.0
|
4
|
+
Summary: Add your description here
|
5
|
+
Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
|
6
|
+
Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
|
7
|
+
License: Apache-2.0
|
8
|
+
Requires-Python: >=3.8
|
9
|
+
Description-Content-Type: text/markdown
|
10
|
+
License-File: LICENSE
|
11
|
+
Requires-Dist: torch
|
12
|
+
Requires-Dist: numpy
|
13
|
+
Dynamic: license-file
|
14
|
+
|
15
|
+
# fastkmeans
|
16
|
+
|
17
|
+

|
18
|
+
[](https://twitter.com/bclavie)
|
19
|
+
<!-- [](https://pepy.tech/project/fastkmeans) -->
|
20
|
+
|
21
|
+
_A fast and efficient k-means implementation for PyTorch, with support for GPU and CPU._
|
22
|
+
|
23
|
+
---
|
24
|
+
|
25
|
+
Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with both CPU and GPU, matching or outpacing `faiss` (except in multi-GPU settings), and it comes with minimal install woes: its declared dependencies are just two packages you likely already have installed, `torch` and `numpy` (the Triton GPU backend additionally imports `triton`).
|
26
|
+
|
27
|
+
### Get started
|
28
|
+
|
29
|
+
```sh
|
30
|
+
[uv] pip install fastkmeans
|
31
|
+
```
|
32
|
+
|
33
|
+
... and that's all you need to do! `FastKMeans` is now ready to use.
|
34
|
+
|
35
|
+
### So what does this do?
|
36
|
+
|
37
|
+
There's very, very little to this library. It provides a single interface, `FastKMeans`, which you can use by importing it from `fastkmeans`. This interface has been designed to be a slot-in replacement for existing FAISS implementations, while being *mostly* sklearn-compliant as well. Effectively, this means that four methods are exposed:
|
38
|
+
|
39
|
+
- train(): mimics the FAISS API, training the model on a dataset.
|
40
|
+
- fit(): mimics the sklearn API, training the model on a dataset.
|
41
|
+
- predict(): mimics the sklearn API, using the trained clusters to predict where new points belong.
|
42
|
+
- fit_predict(): mimics the sklearn API, chaining the two calls above.
|
43
|
+
|
44
|
+
#### Behaviour
|
45
|
+
|
46
|
+
Whenever possible, the library will attempt to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see what the arguments are, as they're very straightforward. The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early stopping mechanism based on a `tol` parameter, which stops the algorithm once the total movement of all centroids between two iterations falls below `tol`. This is unlike faiss, whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
|
47
|
+
|
48
|
+
#### Chunking
|
49
|
+
|
50
|
+
The algorithm is implemented with double-chunking logic, where both the data points and the centroids are split into moderately-sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing chunk sizes will speed up computations at the cost of memory usage, and decreasing them will have the reverse effect.
|
51
|
+
|
52
|
+
### Why `fastkmeans`?
|
53
|
+
|
54
|
+
The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, across its Stanford implementation, its PyLate implementation, and the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great on their own merits, have flaws for this particular use:
|
55
|
+
|
56
|
+
- `faiss` is highly optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install as there are no "perfect" PyPI wheels, with many segfault issues reported, as the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not installed via conda too. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
|
57
|
+
- `fast-pytorch-kmeans` is a great library which provides a lightning-fast k-means implementation in PyTorch: in fact, it's faster than this library! However, it relies on highly vectorized operations which are exceedingly memory-hungry, and it consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
|
58
|
+
- `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
|
59
|
+
|
60
|
+
There are other libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
|
61
|
+
|
62
|
+
### Limitations
|
63
|
+
|
64
|
+
- On a few toy datasets & MNIST, `fastkmeans` reaches roughly the same NMI as `faiss` and `scikit-learn`, indicating that it creates at least somewhat coherent clusters. However, it's not extensively tested, especially in non-ColBERT uses, so your mileage may vary.
|
65
|
+
- The "chunking" defaults used to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
|
66
|
+
- The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up. This might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
|
67
|
+
|
68
|
+
### Speed
|
69
|
+
|
70
|
+
Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, 131072 and 262144 clusters, each with 100 data points per cluster, i.e. n_clusters × 100 points in total).
|
71
|
+
|
72
|
+

|
@@ -0,0 +1,8 @@
|
|
1
|
+
fastkmeans/__init__.py,sha256=mPLGfhksBBfB1dTkPBPrPg5E1qixixl_Qritc6A10AI,79
|
2
|
+
fastkmeans/kmeans.py,sha256=RLDeaCIbKpCJBzV2XO3JPgLwjrhJ6vtA997Gyzm_GyA,13651
|
3
|
+
fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
|
4
|
+
fastkmeans-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
5
|
+
fastkmeans-0.2.0.dist-info/METADATA,sha256=2hsTZr0t2_CNhCyrECIBk2S4ZB7ytF0YaDqgcgHDvDc,6791
|
6
|
+
fastkmeans-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
7
|
+
fastkmeans-0.2.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
|
8
|
+
fastkmeans-0.2.0.dist-info/RECORD,,
|
@@ -1,13 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.4
|
2
|
-
Name: fastkmeans
|
3
|
-
Version: 0.1.0
|
4
|
-
Summary: Add your description here
|
5
|
-
Author-email: Ben Clavié <bc@answer.ai>
|
6
|
-
Maintainer-email: Ben Clavié <bc@answer.ai>
|
7
|
-
License: Apache-2.0
|
8
|
-
Requires-Python: >=3.8
|
9
|
-
Description-Content-Type: text/markdown
|
10
|
-
License-File: LICENSE
|
11
|
-
Requires-Dist: torch
|
12
|
-
Requires-Dist: numpy
|
13
|
-
Dynamic: license-file
|
@@ -1,7 +0,0 @@
|
|
1
|
-
fastkmeans/__init__.py,sha256=eIDvwiAarlwyo4n6AvHsV3cPqkbLuILUE1gFAjpzItY,79
|
2
|
-
fastkmeans/kmeans.py,sha256=HO44oEOyI8esRDpUJYIq6tHXXVgysSRazoMSIJ2HkX8,12016
|
3
|
-
fastkmeans-0.1.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
4
|
-
fastkmeans-0.1.0.dist-info/METADATA,sha256=Fy5bypSuzqwdPs-Qw2fTJ8RZn2b9AKj9POyvv4U7gkk,344
|
5
|
-
fastkmeans-0.1.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
6
|
-
fastkmeans-0.1.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
|
7
|
-
fastkmeans-0.1.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|