fastkmeans 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

fastkmeans/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .kmeans import FastKMeans
 
 __all__ = ["FastKMeans"]
-__version__ = "0.1.0"
+__version__ = "0.2.0"
fastkmeans/kmeans.py CHANGED
@@ -3,25 +3,45 @@ import time
 import torch
 import numpy as np
 
-def _get_device(preset: str = None):
-    if preset: return preset
-    if torch.cuda.is_available(): return 'cuda'
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return 'mps'
-    return 'cpu'
+from fastkmeans.triton_kernels import chunked_kmeans_kernel
+
+def _get_device(preset: str | int | torch.device | None = None):
+    if isinstance(preset, torch.device):
+        return preset
+    if isinstance(preset, str):
+        return torch.device(preset)
+    if torch.cuda.is_available():  # cuda currently handles both AMD and NVIDIA GPUs
+        return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        return torch.device('mps')
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return torch.device(f'xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}')
+    return torch.device('cpu')
+
+
+def _is_bfloat16_supported(device: torch.device):
+    if device.type == 'cuda':
+        return torch.cuda.is_bf16_supported()
+    elif device.type == 'xpu' and hasattr(torch.xpu, 'is_bf16_supported'):
+        return torch.xpu.is_bf16_supported()
+    else:
+        return False
+
 
 @torch.inference_mode()
 def _kmeans_torch_double_chunked(
     data: torch.Tensor,
     data_norms: torch.Tensor,
     k: int,
+    device: torch.device,
+    dtype: torch.dtype | None = None,
     max_iters: int = 25,
     tol: float = 1e-8,
-    device: str = None,
-    dtype: torch.dtype = None,
     chunk_size_data: int = 50_000,
     chunk_size_centroids: int = 10_000,
     max_points_per_centroid: int = 256,
     verbose: bool = False,
+    use_triton: bool | None = None,
 ):
     """
     An efficient kmeans implementation that minimises OOM risks on modern hardware by using conservative double chunking.
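For context, here is a quick sketch of how the new device resolution behaves. This is editorial commentary, not part of the diff; both helpers are private APIs, and the comments assume a single-GPU CUDA machine:

```python
import torch

# Internal helpers introduced in the hunk above.
from fastkmeans.kmeans import _get_device, _is_bfloat16_supported

_get_device("mps")        # explicit strings and torch.devices are honoured as-is
_get_device(1)            # int -> 'cuda:1' if that index exists, else falls back to 'cuda:0'
_get_device()             # auto-detect order: cuda > mps > xpu > cpu
_is_bfloat16_supported(torch.device("cpu"))  # False: only cuda/xpu devices can report True
```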
@@ -32,6 +52,10 @@ def _kmeans_torch_double_chunked(
     labels_cpu : torch.Tensor, shape (n_samples_used,), long
         Where n_samples_used can be smaller than the original if subsampling occurred.
     """
+
+    if dtype is None:
+        dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
+
     n_samples_original, n_features = data.shape
     n_samples = n_samples_original
 
@@ -47,19 +71,16 @@ def _kmeans_torch_double_chunked(
     if n_samples < k:
         raise ValueError(f"Number of training points ({n_samples}) is less than k ({k}).")
 
-    if dtype is None: dtype = torch.float16 if device == 'cuda' else torch.float32
-
     # centroid init -- random is the only supported init
     rand_indices = torch.randperm(n_samples)[:k]
     centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
     prev_centroids = centroids.clone()
 
-    labels = torch.empty(n_samples, dtype=torch.long, device='cpu')  # Keep labels on CPU
-
+    labels = torch.empty(n_samples, dtype=torch.int64, device='cpu')  # Keep labels on CPU
 
     for iteration in range(max_iters):
         iteration_start_time = time.time()
-
+
         centroid_norms = (centroids ** 2).sum(dim=1)
         cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
         cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
@@ -71,58 +92,62 @@ def _kmeans_torch_double_chunked(
             data_chunk = data[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
             data_chunk_norms = data_norms[start_idx:end_idx].to(device=device, dtype=dtype, non_blocking=True)
             batch_size = data_chunk.size(0)
-
-            best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
-            best_ids = torch.zeros((batch_size,), device=device, dtype=torch.long)
-
-            c_start = 0
-            while c_start < k:
-                c_end = min(c_start + chunk_size_centroids, k)
-                centroid_chunk = centroids[c_start:c_end]
-                centroid_chunk_norms = centroid_norms[c_start:c_end]
-
-                dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
-                dist_chunk = dist_chunk.addmm_(
-                    data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
+            best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
+
+            if use_triton:
+                chunked_kmeans_kernel(
+                    data_chunk=data_chunk,
+                    data_chunk_norms=data_chunk_norms,
+                    centroids=centroids,
+                    centroids_sqnorm=centroid_norms,
+                    best_ids=best_ids,
                 )
+            else:
+                best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
+                c_start = 0
+                while c_start < k:
+                    c_end = min(c_start + chunk_size_centroids, k)
+                    centroid_chunk = centroids[c_start:c_end]
+                    centroid_chunk_norms = centroid_norms[c_start:c_end]
+
+                    dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
+                    dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
 
-                local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
-                improved_mask = local_min_vals < best_dist
-                best_dist[improved_mask] = local_min_vals[improved_mask]
-                best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                    local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
+                    improved_mask = local_min_vals < best_dist
+                    best_dist[improved_mask] = local_min_vals[improved_mask]
+                    best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
 
-                c_start = c_end
+                    c_start = c_end
 
             cluster_sums.index_add_(0, best_ids, data_chunk.float())
-            cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, dtype=torch.float32))
+            cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
 
             labels[start_idx:end_idx] = best_ids.to('cpu', non_blocking=True)
             start_idx = end_idx
 
-        new_centroids = torch.zeros_like(centroids, device=device, dtype=torch.float32)
+        new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
         non_empty = (cluster_counts > 0)
-        new_centroids[non_empty] = (
-            cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)
-        )
+        new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
 
         empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
         if len(empty_ids) > 0:
             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device='cpu')
-            random_data = data[reinit_indices].to(device=device, dtype=torch.float32)
+            random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
             new_centroids[empty_ids] = random_data
 
-        new_centroids = new_centroids.to(dtype=dtype)
-
         shift = torch.norm(new_centroids - prev_centroids.to(new_centroids.device), dim=1).sum().item()
         centroids = new_centroids
 
         prev_centroids = centroids.clone()
-
+
         iteration_time = time.time() - iteration_start_time
-        if verbose: print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
-
+        if verbose:
+            print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
+
         if shift < tol:
-            if verbose: print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
+            if verbose:
+                print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
             break
 
     centroids_cpu = centroids.to('cpu', dtype=torch.float32)
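The non-Triton fallback above computes squared Euclidean distances via the expansion `||x - c||^2 = ||x||^2 + ||c||^2 - 2<x, c>`, fused into a single in-place `addmm_`. A standalone sketch of one data-chunk by centroid-chunk step, with toy shapes and illustrative variable names, not the library's API:

```python
import torch

x = torch.randn(6, 4)   # one data chunk, shape [batch, d]
c = torch.randn(3, 4)   # one centroid chunk, shape [n_centroids, d]

# Broadcast ||x||^2 + ||c||^2 to [batch, n_centroids], then fuse in -2 * x @ c.T:
dist = (x ** 2).sum(dim=1, keepdim=True) + (c ** 2).sum(dim=1).unsqueeze(0)
dist.addmm_(x, c.t(), alpha=-2.0, beta=1.0)

# Matches the naive pairwise computation up to floating-point error.
ref = ((x[:, None, :] - c[None, :, :]) ** 2).sum(dim=-1)
assert torch.allclose(dist, ref, atol=1e-4)
```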
@@ -155,6 +180,9 @@ class FastKMeans:
         Chunk size along the data dimension for assignment/update steps.
     chunk_size_centroids : int, default=10_000
         Chunk size along the centroid dimension for assignment/update steps.
+    use_triton : bool | None, default=None
+        Use the fast Triton backend for the assignment/update steps.
+        If None, the Triton backend will be enabled for modern GPUs.
     """
 
     def __init__(
@@ -168,33 +196,36 @@ class FastKMeans:
         max_points_per_centroid: int = 256,
         chunk_size_data: int = 50_000,
         chunk_size_centroids: int = 10_000,
-        device: str = None,
+        device: str | int | torch.device | None = None,
         dtype: torch.dtype = None,
         pin_gpu_memory: bool = True,
         verbose: bool = False,
         nredo: int = 1,  # for compatibility only
+        use_triton: bool | None = None,
     ):
         self.d = d
         self.k = k
         self.niter = niter
         self.tol = tol
-        self.gpu = gpu
         self.seed = seed
         self.max_points_per_centroid = max_points_per_centroid
         self.chunk_size_data = chunk_size_data
         self.chunk_size_centroids = chunk_size_centroids
+        self.device = _get_device("cpu" if gpu is False else device)
         self.centroids = None
-        if device not in [None, 'cuda'] and self.gpu: print("Warning: device is set to 'cuda' but gpu is True, ignoring 'device' argument and setting it to 'cuda'!")
-        self.device = 'cuda' if self.gpu else device
         self.dtype = dtype
         self.pin_gpu_memory = pin_gpu_memory
-        if nredo != 1: raise ValueError("nredo must be 1, redos not currently supported")
         self.verbose = verbose
+        if use_triton is not False:
+            use_triton = _is_bfloat16_supported(self.device)  # assume triton is supported if the GPU supports bfloat16
+        self.use_triton = use_triton
+        if nredo != 1:
+            raise ValueError("nredo must be 1, redos not currently supported")
 
     def train(self, data: np.ndarray):
         """
         Trains (fits) the KMeans model on the given data and sets `self.centroids`. Designed to mimic faiss's `train()` method.
-
+
         Parameters
         ----------
         data : np.ndarray of shape (n_samples, d), float32
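Taken together, the constructor changes mean `gpu=False` is now simply a forced-CPU override and the Triton path is toggled by device capability. A hedged usage sketch, with arbitrary dimension and cluster counts:

```python
from fastkmeans import FastKMeans

FastKMeans(d=128, k=8192, gpu=False)         # pinned to CPU, `device` is ignored
FastKMeans(d=128, k=8192, device=1)          # int index, resolved by _get_device
FastKMeans(d=128, k=8192, use_triton=False)  # force the chunked pure-PyTorch path
FastKMeans(d=128, k=8192)                    # Triton auto-enabled when the resolved
                                             # device reports bfloat16 support
```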
@@ -224,6 +255,7 @@ class FastKMeans:
             chunk_size_centroids=self.chunk_size_centroids,
             max_points_per_centroid=self.max_points_per_centroid,
             verbose=self.verbose,
+            use_triton=self.use_triton,
         )
         self.centroids = centroids.numpy()
 
@@ -250,11 +282,7 @@ class FastKMeans:
 
         # We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
         centroids_torch = torch.from_numpy(self.centroids)
-        device = centroids_torch.device.type
-        if device == 'cpu' and self.gpu and torch.cuda.is_available():
-            device = 'cuda'  # If user asked for GPU, put centroids there
-
-        centroids_torch = centroids_torch.to(device=device, dtype=torch.float32)
+        centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
         centroid_norms = (centroids_torch ** 2).sum(dim=1)
 
         n_samples = data_torch.shape[0]
@@ -263,30 +291,37 @@ class FastKMeans:
         start_idx = 0
         while start_idx < n_samples:
             end_idx = min(start_idx + self.chunk_size_data, n_samples)
-            data_chunk = data_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
-            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=device, dtype=torch.float32, non_blocking=True)
-            batch_size = data_chunk.size(0)
 
-            best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=torch.float32)
-            best_ids = torch.zeros((batch_size,), device=device, dtype=torch.long)
-
-            c_start = 0
-            k = centroids_torch.shape[0]
-            while c_start < k:
-                c_end = min(c_start + self.chunk_size_centroids, k)
-                centroid_chunk = centroids_torch[c_start:c_end]
-                centroid_chunk_norms = centroid_norms[c_start:c_end]
-
-                dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
-                dist_chunk = dist_chunk.addmm_(
-                    data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0
+            data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+            data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+            batch_size = data_chunk.size(0)
+            best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
+
+            if self.use_triton:
+                chunked_kmeans_kernel(
+                    data_chunk,
+                    data_chunk_norms,
+                    centroids_torch,
+                    centroid_norms,
+                    best_ids,
                 )
-
-                local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
-                improved_mask = local_min_vals < best_dist
-                best_dist[improved_mask] = local_min_vals[improved_mask]
-                best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
-                c_start = c_end
+            else:
+                best_dist = torch.full((batch_size,), float('inf'), device=self.device, dtype=torch.float32)
+                c_start = 0
+                k = centroids_torch.shape[0]
+                while c_start < k:
+                    c_end = min(c_start + self.chunk_size_centroids, k)
+                    centroid_chunk = centroids_torch[c_start:c_end]
+                    centroid_chunk_norms = centroid_norms[c_start:c_end]
+
+                    dist_chunk = data_chunk_norms.unsqueeze(1) + centroid_chunk_norms.unsqueeze(0)
+                    dist_chunk = dist_chunk.addmm_(data_chunk, centroid_chunk.t(), alpha=-2.0, beta=1.0)
+
+                    local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
+                    improved_mask = local_min_vals < best_dist
+                    best_dist[improved_mask] = local_min_vals[improved_mask]
+                    best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                    c_start = c_end
 
             labels[start_idx:end_idx] = best_ids.to('cpu')
             start_idx = end_idx
fastkmeans/triton_kernels.py ADDED
@@ -0,0 +1,105 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.heuristics(
+    {
+        "BLOCK_M": lambda a: 128 if a["D"] <= 128 else (64 if a["D"] <= 512 else 32),
+        "BLOCK_N": lambda a: 64 if a["D"] <= 128 else 32,
+        "num_warps": lambda a: 4 if a["D"] <= 64 else 8,
+        "num_stages": lambda a: 2 if a["D"] < 64 else 1,
+    }
+)
+@triton.jit
+def _chunked_kmeans_kernel(
+    data_ptr,              # [B, D], row-major
+    x_norm_ptr,            # [B], precomputed squared L2 norms of data
+    centroids_ptr,         # [C, D], row-major
+    centroids_sqnorm_ptr,  # [C], precomputed squared L2 norms of centroids
+    best_ids_ptr,          # [B], int64 (to store best centroid indices)
+    B,                     # number of data points
+    C,                     # number of centroids
+    D: tl.constexpr,       # dimension, or number of features
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """
+    Each Triton block processes BLOCK_M rows of data. The kernel:
+      1) loads those rows (and their precomputed norms from x_norm_ptr),
+      2) loops over all centroids in chunks of BLOCK_N,
+      3) computes distances, finds the best centroid,
+      4) writes out the best centroid index for each data point.
+    """
+    # 1) Identify which data rows this block handles
+    block_id = tl.program_id(axis=0)
+    row_start = block_id * BLOCK_M
+    rows = row_start + tl.arange(0, BLOCK_M)
+    mask = rows < B
+
+    # 2) Load data rows and precomputed x_norm: shape [BLOCK_M, D]
+    row_offsets = rows[:, None] * D + tl.arange(0, D)
+    x = tl.load(data_ptr + row_offsets, mask=mask[:, None], other=0.0)
+
+    # shape: [BLOCK_M]
+    x_norm = tl.load(x_norm_ptr + rows, mask=mask, other=0.0)
+
+    # Prepare "best distance" + "best index"
+    best_dist = tl.full([BLOCK_M], 1e38, dtype=tl.float32)
+    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
+
+    # 3) Iterate over the centroids in chunks of BLOCK_N
+    for chunk in range(0, C, BLOCK_N):
+        cids = chunk + tl.arange(0, BLOCK_N)
+        c_mask = cids < C
+
+        # Load sub-block of centroids: shape [BLOCK_N, D]
+        c_offsets = cids[:, None] * D + tl.arange(0, D)
+        cvals = tl.load(centroids_ptr + c_offsets, mask=c_mask[:, None], other=0.0).to(x.dtype)
+
+        # Load centroid norms: shape [BLOCK_N]
+        c_sqnorm = tl.load(centroids_sqnorm_ptr + cids, mask=c_mask, other=0.0).to(x.dtype)
+
+        # Compute distance = x_norm + c_sqnorm - 2 * dot(x, c)
+        dots = tl.dot(x, tl.trans(cvals))  # shape [BLOCK_M, BLOCK_N]
+        dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
+
+        # Find the argmin along the BLOCK_N dimension
+        local_min_vals, local_min_idx = tl.min(dist_chunk, axis=1, return_indices=True)
+
+        improved = local_min_vals < best_dist
+        best_dist = tl.where(improved, local_min_vals, best_dist)
+        best_idx = tl.where(improved, chunk + local_min_idx, best_idx)
+
+    # 4) Write out the best centroid indices
+    tl.store(best_ids_ptr + rows, best_idx, mask=mask)
+
+
+def chunked_kmeans_kernel(
+    data_chunk: torch.Tensor,
+    data_chunk_norms: torch.Tensor,
+    centroids: torch.Tensor,
+    centroids_sqnorm: torch.Tensor,
+    best_ids: torch.Tensor,
+):
+    """
+    Launches the Triton kernel to assign each point to its nearest centroid in one pass.
+
+    best_ids: pre-allocated [B] (int64) to store the best centroid ID of each point.
+    """
+    B, D = data_chunk.shape
+    C = centroids.shape[0]
+
+    def grid(meta):
+        return (triton.cdiv(B, meta["BLOCK_M"]),)
+
+    _chunked_kmeans_kernel[grid](
+        data_chunk,
+        data_chunk_norms,
+        centroids,
+        centroids_sqnorm,
+        best_ids,
+        B,  # num_points
+        C,  # num_centroids
+        D,  # dimension, or number of features
+    )
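A minimal smoke test for the wrapper above, assuming a CUDA GPU with Triton installed. Tensor sizes are illustrative; note that `D` feeds `tl.arange`, so a power-of-two feature dimension is the safe choice:

```python
import torch

from fastkmeans.triton_kernels import chunked_kmeans_kernel

device = torch.device("cuda")
B, C, D = 4096, 256, 128

data = torch.randn(B, D, device=device, dtype=torch.float16)
centroids = torch.randn(C, D, device=device, dtype=torch.float16)
best_ids = torch.zeros(B, device=device, dtype=torch.int64)  # pre-allocated output

chunked_kmeans_kernel(
    data_chunk=data,
    data_chunk_norms=(data ** 2).sum(dim=1),
    centroids=centroids,
    centroids_sqnorm=(centroids ** 2).sum(dim=1),
    best_ids=best_ids,
)

# Cross-check against a plain PyTorch argmin; fp16 rounding can flip near-ties,
# so we report an agreement rate rather than asserting exact equality.
ref = torch.cdist(data.float(), centroids.float()).argmin(dim=1)
print((best_ids == ref).float().mean().item())
```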
fastkmeans-0.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,72 @@
+Metadata-Version: 2.4
+Name: fastkmeans
+Version: 0.2.0
+Summary: Add your description here
+Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
+Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
+License: Apache-2.0
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch
+Requires-Dist: numpy
+Dynamic: license-file
+
+# fastkmeans
+
+![Python Versions](https://img.shields.io/badge/Python-3.8_3.9_3.10_3.11_3.12_3.13-blue)
+[![Twitter Follow](https://img.shields.io/twitter/follow/bclavie?style=social)](https://twitter.com/bclavie)
+<!-- [![Downloads](https://static.pepy.tech/badge/fastkmeans/month)](https://pepy.tech/project/fastkmeans) -->
+
+_A fast and efficient k-means implementation for PyTorch, with support for GPU and CPU._
+
+---
+
+Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with both CPU and GPU, matching or outpacing `faiss` (except in multi-GPU settings), and it comes without any install woes, relying on just two dependencies you already have installed anyway: `torch` and `numpy`.
+
+### Get started
+
+```sh
+[uv] pip install fastkmeans
+```
+
+... and that's all you need to do! `FastKMeans` is now ready to use.
+
+### So what does this do?
+
+There's very, very little to this library. It provides a single interface, `FastKMeans`, which you can use by importing it from `fastkmeans`. This interface has been designed to be a slot-in replacement for existing FAISS implementations, while being *mostly* sklearn-compliant as well. Effectively, this means that four methods are exposed:
+
+- train(): mimics the FAISS API, training the model on a dataset.
+- fit(): mimics the sklearn API, training the model on a dataset.
+- predict(): mimics the sklearn API, using the trained clusters to predict where new points belong.
+- fit_predict(): mimics the sklearn API, chaining the two calls above.
+
+#### Behaviour
+
+Whenever possible, the library will attempt to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see what the arguments are, as they're very straightforward. The default behaviour of the library mostly follows faiss's, including downsampling the data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early-stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids move less than `tol` between iterations. This is unlike faiss's default behaviour, which is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
+
+#### Chunking
+
+The algorithm is implemented with double-chunking logic, where both the data points and the centroids are split into moderately sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing chunk sizes will speed up computations at the cost of memory usage, while decreasing them will have the reverse effect.
+
+### Why `fastkmeans`?
+
+The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, in its Stanford implementation, its PyLate implementation, and for the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great on their own merits, have flaws for this particular use:
+
+- `faiss` is highly optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install, as there are no "perfect" PyPI wheels, with many segfault issues reported, since the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not also installed via conda. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
+- `fast-pytorch-kmeans` is a great library which provides a lightning-fast k-means implementation in PyTorch: in fact, it's faster than this library! However, it relies on highly vectorized operations which are exceedingly memory-hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
+- `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
+
+There are other libraries (such as NVidia's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
+
+### Limitations
+
+- On a few toy datasets & MNIST, `fastkmeans` reaches roughly the same NMI as `faiss` and `scikit-learn`, indicating that it creates at least somewhat coherent clusters. However, it's not extensively tested, especially in non-ColBERT uses, so your mileage may vary.
+- The "chunking" defaults used to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
+- The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up; this might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
+
+### Speed
+
+Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales commonly used in ColBERT-style indexing (8192, 16384, 32768, 65536, 131072 and 262144 clusters, each with cluster_size*100 data points).
+
+![fastkmeans benchmark](./benchmark_plots/4090_benchmark.png)
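To make the four-method API concrete, here is a minimal end-to-end sketch using synthetic data; all sizes are illustrative, and the point count follows the ~100-points-per-cluster regime described above:

```python
import numpy as np
from fastkmeans import FastKMeans

data = np.random.rand(409_600, 128).astype(np.float32)

kmeans = FastKMeans(d=128, k=4096, niter=25, tol=1e-8)  # tol=-1 restores faiss-style fixed iterations
kmeans.train(data)             # faiss-style; kmeans.fit(data) is the sklearn-style twin
labels = kmeans.predict(data)  # assign points to the trained centroids

print(kmeans.centroids.shape)  # (4096, 128)
print(labels.shape)            # (409600,), one cluster id per input row
```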
fastkmeans-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+fastkmeans/__init__.py,sha256=mPLGfhksBBfB1dTkPBPrPg5E1qixixl_Qritc6A10AI,79
+fastkmeans/kmeans.py,sha256=RLDeaCIbKpCJBzV2XO3JPgLwjrhJ6vtA997Gyzm_GyA,13651
+fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
+fastkmeans-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+fastkmeans-0.2.0.dist-info/METADATA,sha256=2hsTZr0t2_CNhCyrECIBk2S4ZB7ytF0YaDqgcgHDvDc,6791
+fastkmeans-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+fastkmeans-0.2.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
+fastkmeans-0.2.0.dist-info/RECORD,,
fastkmeans-0.1.0.dist-info/METADATA DELETED
@@ -1,13 +0,0 @@
-Metadata-Version: 2.4
-Name: fastkmeans
-Version: 0.1.0
-Summary: Add your description here
-Author-email: Ben Clavié <bc@answer.ai>
-Maintainer-email: Ben Clavié <bc@answer.ai>
-License: Apache-2.0
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: torch
-Requires-Dist: numpy
-Dynamic: license-file
fastkmeans-0.1.0.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-fastkmeans/__init__.py,sha256=eIDvwiAarlwyo4n6AvHsV3cPqkbLuILUE1gFAjpzItY,79
-fastkmeans/kmeans.py,sha256=HO44oEOyI8esRDpUJYIq6tHXXVgysSRazoMSIJ2HkX8,12016
-fastkmeans-0.1.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-fastkmeans-0.1.0.dist-info/METADATA,sha256=Fy5bypSuzqwdPs-Qw2fTJ8RZn2b9AKj9POyvv4U7gkk,344
-fastkmeans-0.1.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-fastkmeans-0.1.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
-fastkmeans-0.1.0.dist-info/RECORD,,