fastkmeans 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

fastkmeans/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from .kmeans import FastKMeans
 
  __all__ = ["FastKMeans"]
- __version__ = "0.3.0"
+ __version__ = "0.5.0"
fastkmeans/kmeans.py CHANGED
@@ -1,26 +1,37 @@
+ from __future__ import annotations
+
  import time
 
  import torch
  import numpy as np
 
+ try:
+     from fastkmeans.triton_kernels import triton_kmeans
+
+     HAS_TRITON = True
+ except ImportError:
+     triton_kmeans = None
+     HAS_TRITON = False
+
+
  def _get_device(preset: str | int | torch.device | None = None):
      if isinstance(preset, torch.device):
          return preset
      if isinstance(preset, str):
          return torch.device(preset)
-     if torch.cuda.is_available(): # cuda currently handles both AMD and NVIDIA GPUs
+     if torch.cuda.is_available():  # cuda currently handles both AMD and NVIDIA GPUs
          return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
-     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-         return torch.device('mps')
-     if hasattr(torch, 'xpu') and torch.xpu.is_available():
-         return torch.device(f'xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}')
-     return torch.device('cpu')
+     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         return torch.device("mps")
+     if hasattr(torch, "xpu") and torch.xpu.is_available():
+         return torch.device(f"xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}")
+     return torch.device("cpu")
 
 
- def _is_bfloat16_supported(device:torch.device):
-     if device.type == 'cuda':
+ def _is_bfloat16_supported(device: torch.device):
+     if device.type == "cuda":
          return torch.cuda.is_bf16_supported()
-     elif device.type == 'xpu' and hasattr(torch.xpu, 'is_bf16_supported'):
+     elif device.type == "xpu" and hasattr(torch.xpu, "is_bf16_supported"):
          return torch.xpu.is_bf16_supported()
      else:
          return False
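For orientation, the resolution order implemented above is: explicit `torch.device` or string, then CUDA/ROCm, then MPS, then XPU, then CPU. A minimal sketch of calling it; this imports the private helper shown in the hunk, so it assumes `fastkmeans` is installed, and the printed devices depend on your hardware:

```python
import torch

from fastkmeans.kmeans import _get_device  # private helper from the hunk above

# Explicit presets are returned as-is (strings are parsed into devices).
assert _get_device(torch.device("cpu")) == torch.device("cpu")
assert _get_device("cpu") == torch.device("cpu")

# With no preset, backends are probed in order: CUDA/ROCm -> MPS -> XPU -> CPU.
print(_get_device())   # e.g. cuda:0 on an NVIDIA/AMD machine, cpu otherwise

# An integer preset picks a GPU index, falling back to 0 if out of range.
print(_get_device(1))  # cuda:1 with >=2 GPUs, else cuda:0
```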
@@ -52,10 +63,11 @@ def _kmeans_torch_double_chunked(
      """
 
      if use_triton:
-         from fastkmeans.triton_kernels import chunked_kmeans_kernel
+         if not HAS_TRITON:
+             raise ImportError("Triton is not available. Please install Triton and try again.")
 
      if dtype is None:
-         dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
+         dtype = torch.float16 if device.type in ["cuda", "xpu"] else torch.float32
 
      n_samples_original, n_features = data.shape
      n_samples = n_samples_original
@@ -77,12 +89,12 @@ def _kmeans_torch_double_chunked(
      centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
      prev_centroids = centroids.clone()
 
-     labels = torch.empty(n_samples, dtype=torch.int64, device='cpu') # Keep labels on CPU
+     labels = torch.empty(n_samples, dtype=torch.int64, device="cpu")  # Keep labels on CPU
 
      for iteration in range(max_iters):
          iteration_start_time = time.time()
 
-         centroid_norms = (centroids ** 2).sum(dim=1)
+         centroid_norms = (centroids**2).sum(dim=1)
          cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
          cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
 
@@ -96,7 +108,7 @@ def _kmeans_torch_double_chunked(
              best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
 
              if use_triton:
-                 chunked_kmeans_kernel(
+                 triton_kmeans(
                      data_chunk=data_chunk,
                      data_chunk_norms=data_chunk_norms,
                      centroids=centroids,
@@ -104,7 +116,7 @@ def _kmeans_torch_double_chunked(
                      best_ids=best_ids,
                  )
              else:
-                 best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
+                 best_dist = torch.full((batch_size,), float("inf"), device=device, dtype=dtype)
                  c_start = 0
                  while c_start < k:
                      c_end = min(c_start + chunk_size_centroids, k)
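The `dist_chunk` computation sits in the unchanged lines elided between these two hunks; both the PyTorch path and the Triton kernel rank centroids with the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * <x, c>, which is why precomputed `data_chunk_norms` and `centroid_norms` are threaded through instead of full distance matrices. A self-contained sanity check of that identity in plain `torch` (shapes are arbitrary):

```python
import torch

x = torch.randn(512, 128)  # a data chunk [B, D]
c = torch.randn(64, 128)   # a centroid chunk [C_chunk, D]

# norm expansion used by both backends
x_norms = (x**2).sum(dim=1)
c_norms = (c**2).sum(dim=1)
dist_sq = x_norms[:, None] + c_norms[None, :] - 2.0 * (x @ c.T)

# matches explicit squared euclidean distances
assert torch.allclose(dist_sq, torch.cdist(x, c).pow(2), atol=1e-3)

# argmin over the centroid axis yields the assignments, as in best_ids above
best_ids = dist_sq.argmin(dim=1)
```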
@@ -117,23 +129,23 @@ def _kmeans_torch_double_chunked(
                      local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                      improved_mask = local_min_vals < best_dist
                      best_dist[improved_mask] = local_min_vals[improved_mask]
-                     best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                     best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
 
                      c_start = c_end
 
              cluster_sums.index_add_(0, best_ids, data_chunk.float())
              cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
 
-             labels[start_idx:end_idx] = best_ids.to('cpu', non_blocking=True)
+             labels[start_idx:end_idx] = best_ids.to("cpu", non_blocking=True)
              start_idx = end_idx
 
          new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
-         non_empty = (cluster_counts > 0)
+         non_empty = cluster_counts > 0
          new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
 
          empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
          if len(empty_ids) > 0:
-             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device='cpu')
+             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device="cpu")
              random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
              new_centroids[empty_ids] = random_data
 
@@ -144,14 +156,16 @@ def _kmeans_torch_double_chunked(
 
          iteration_time = time.time() - iteration_start_time
          if verbose:
-             print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
+             print(
+                 f"Iteration {iteration + 1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}"
+             )
 
          if shift < tol:
              if verbose:
-                 print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
+                 print(f"Converged after {iteration + 1} iterations (shift: {shift:.6f} < tol: {tol})")
              break
 
-     centroids_cpu = centroids.to('cpu', dtype=torch.float32)
+     centroids_cpu = centroids.to("cpu", dtype=torch.float32)
      return centroids_cpu, labels
 
 
@@ -177,9 +191,9 @@ class FastKMeans:
      max_points_per_centroid : int, optional, default=1_000_000_000
          If n_samples > k * max_points_per_centroid, the data will be subsampled to exactly
          k * max_points_per_centroid points before clustering.
-     chunk_size_data : int, default=50_000
+     chunk_size_data : int, default=51_200
          Chunk size along the data dimension for assignment/update steps.
-     chunk_size_centroids : int, default=10_000
+     chunk_size_centroids : int, default=10_240
          Chunk size along the centroid dimension for assignment/update steps.
      use_triton : bool | None, default=None
          Use the fast Triton backend for the assignment/update steps.
@@ -194,14 +208,14 @@ class FastKMeans:
          tol: float = 1e-8,
          gpu: bool = True,
          seed: int = 0,
-         max_points_per_centroid: int = 256,
-         chunk_size_data: int = 50_000,
-         chunk_size_centroids: int = 10_000,
+         max_points_per_centroid: int = 1_000_000_000,
+         chunk_size_data: int = 51_200,
+         chunk_size_centroids: int = 10_240,
          device: str | int | torch.device | None = None,
          dtype: torch.dtype = None,
          pin_gpu_memory: bool = True,
          verbose: bool = False,
-         nredo: int = 1, # for compatibility only
+         nredo: int = 1,  # for compatibility only
          use_triton: bool | None = None,
      ):
          self.d = d
@@ -217,8 +231,11 @@ class FastKMeans:
          self.dtype = dtype
          self.pin_gpu_memory = pin_gpu_memory
          self.verbose = verbose
-         if use_triton is not False:
-             use_triton = _is_bfloat16_supported(self.device) # assume triton is supported if GPU supports bfloat16
+         if use_triton is None:
+             # assume triton kernel is supported if GPU supports bfloat16
+             use_triton = HAS_TRITON and _is_bfloat16_supported(self.device)
+         if use_triton and not HAS_TRITON:
+             raise ValueError("Triton is not available. Please install Triton and try again.")
          self.use_triton = use_triton
          if nredo != 1:
              raise ValueError("nredo must be 1, redos not currently supported")
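Net effect of the new logic: `use_triton=None` (the default) enables the kernel only when the `fastkmeans.triton_kernels` import succeeded and the device reports bfloat16 support, while an explicit `use_triton=True` now fails fast when Triton is missing. A hedged sketch of what a caller sees; the `d`/`k` values are arbitrary and follow the faiss-style constructor:

```python
from fastkmeans import FastKMeans

# Default: auto-detect. True only if Triton imported and the GPU supports bf16.
km = FastKMeans(d=128, k=1024, use_triton=None)
print(km.use_triton)  # e.g. True on an Ampere-or-newer CUDA GPU with triton

# Forcing the Triton backend without Triton installed now raises immediately:
# ValueError: Triton is not available. Please install Triton and try again.
km_forced = FastKMeans(d=128, k=1024, use_triton=True)
```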
@@ -237,10 +254,10 @@ class FastKMeans:
 
          # Move data to PyTorch CPU Tensor
          data_torch = torch.from_numpy(data)
-         data_norms_torch = (data_torch ** 2).sum(dim=1)
+         data_norms_torch = (data_torch**2).sum(dim=1)
 
          device = _get_device(self.device)
-         if device == 'cuda' and self.pin_gpu_memory:
+         if device == "cuda" and self.pin_gpu_memory:
              data_torch = data_torch.pin_memory()
              data_norms_torch = data_norms_torch.pin_memory()
 
@@ -275,33 +292,33 @@ class FastKMeans:
          -------
          labels : np.ndarray of shape (n_samples,), int64
          """
-         if self.use_triton:
-             from fastkmeans.triton_kernels import chunked_kmeans_kernel
          if self.centroids is None:
              raise RuntimeError("Must call train() or fit() before predict().")
 
          data_torch = torch.from_numpy(data)
-         data_norms_torch = (data_torch ** 2).sum(dim=1)
+         data_norms_torch = (data_torch**2).sum(dim=1)
 
          # We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
          centroids_torch = torch.from_numpy(self.centroids)
          centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
-         centroid_norms = (centroids_torch ** 2).sum(dim=1)
+         centroid_norms = (centroids_torch**2).sum(dim=1)
 
          n_samples = data_torch.shape[0]
-         labels = torch.empty(n_samples, dtype=torch.long, device='cpu')
+         labels = torch.empty(n_samples, dtype=torch.long, device="cpu")
 
          start_idx = 0
          while start_idx < n_samples:
              end_idx = min(start_idx + self.chunk_size_data, n_samples)
 
              data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
-             data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+             data_chunk_norms = data_norms_torch[start_idx:end_idx].to(
+                 device=self.device, dtype=torch.float32, non_blocking=True
+             )
              batch_size = data_chunk.size(0)
              best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
 
              if self.use_triton:
-                 chunked_kmeans_kernel(
+                 triton_kmeans(
                      data_chunk,
                      data_chunk_norms,
                      centroids_torch,
@@ -309,7 +326,7 @@ class FastKMeans:
                      best_ids,
                  )
              else:
-                 best_dist = torch.full((batch_size,), float('inf'), device=self.device, dtype=torch.float32)
+                 best_dist = torch.full((batch_size,), float("inf"), device=self.device, dtype=torch.float32)
                  c_start = 0
                  k = centroids_torch.shape[0]
                  while c_start < k:
@@ -323,10 +340,10 @@ class FastKMeans:
                      local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                      improved_mask = local_min_vals < best_dist
                      best_dist[improved_mask] = local_min_vals[improved_mask]
-                     best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                     best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
                      c_start = c_end
 
-             labels[start_idx:end_idx] = best_ids.to('cpu')
+             labels[start_idx:end_idx] = best_ids.to("cpu")
              start_idx = end_idx
 
          return labels.numpy()
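End to end, `predict()` reruns the same chunked assignment against the trained centroids, without updates. A minimal usage sketch under the faiss-style API the file exposes (numpy in, numpy out; the sizes here are illustrative):

```python
import numpy as np

from fastkmeans import FastKMeans

data = np.random.randn(100_000, 128).astype(np.float32)

km = FastKMeans(d=128, k=1024)
km.train(data)             # learns km.centroids; fit() is also accepted
labels = km.predict(data)  # chunked nearest-centroid assignment, no updates

assert labels.shape == (100_000,)
assert labels.dtype == np.int64  # predict() returns int64 labels, per above
```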
fastkmeans/triton_kernels.py CHANGED
@@ -1,105 +1,134 @@
+ from __future__ import annotations
+
+ from contextlib import contextmanager
+
  import torch
  import triton
  import triton.language as tl
 
 
+ @contextmanager
+ def device_guard(tensor: torch.Tensor):
+     """Context manager to ensure that the Triton kernel launches on the correct device."""
+     if tensor.device.type == "cuda":  # NVIDIA or AMD/ROCm
+         with torch.cuda.device_of(tensor):
+             yield
+     elif tensor.device.type == "xpu":  # Intel GPUs
+         with torch.xpu.device_of(tensor):
+             yield
+     else:  # CPU or other back-ends
+         yield
+
+
  @triton.heuristics(
      {
-         "BLOCK_M": lambda a: 128 if a["D"] <= 128 else (64 if a["D"] <= 512 else 32),
-         "BLOCK_N": lambda a: 64 if a["D"] <= 128 else 32,
-         "num_warps": lambda a: 4 if a["D"] <= 64 else 8,
-         "num_stages": lambda a: 2 if a["D"] < 64 else 1,
+         "BLOCK_M": lambda x: 128 if x["D"] <= 384 else 64,
+         "BLOCK_N": lambda x: 128 if x["D"] <= 384 else 64,
+         "BLOCK_K": lambda x: 16 if x["D"] <= 32 or x["D"] > 384 else 32,
+         "GROUP_SIZE_M": lambda x: 8 if x["D"] <= 32 else 16,
+         "num_warps": lambda x: 4,
      }
  )
  @triton.jit
- def _chunked_kmeans_kernel(
-     data_ptr, # [B, D], row-major
-     x_norm_ptr, # [B], precomputed L2 norms of data
-     centroids_ptr, # [C, D], row-major
-     centroids_sqnorm_ptr, # [C], precomputed L2 norms of centroids
-     best_ids_ptr, # [B], int32 (to store best centroid indices)
-     B, # number of data points
-     C, # number of centroids
-     D: tl.constexpr, # dimension, or number of features
+ def _kmeans_kernel(
+     x_ptr,
+     x_norm_ptr,
+     c_ptr,
+     c_norm_ptr,
+     best_dist_ptr,
+     best_idx_ptr,
+     B,
+     C,
+     D: tl.constexpr,
      BLOCK_M: tl.constexpr,
      BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
  ):
-     """
-     Each Triton block processes BLOCK_M rows of data. The kernel:
-       1) loads those rows (and their precomputed norms from x_norm_ptr),
-       2) loops over all centroids in chunks of BLOCK_N,
-       3) computes distances, finds the best centroid,
-       4) writes out the best centroid index for each data point.
-     """
-     # 1) Identify which data rows this block handles
-     block_id = tl.program_id(axis=0)
-     row_start = block_id * BLOCK_M
+     # Map flat CTA id to (pid_m, pid_n) in “grouped” launch order
+     pid = tl.program_id(axis=0)
+
+     num_pid_m = tl.cdiv(B, BLOCK_M)  # row-tiles
+     num_pid_n = tl.cdiv(C, BLOCK_N)  # centroid-tiles
+
+     # Super-group into GROUP_SIZE_M blocks to minimize loading from global memory
+     num_pid_in_grp = GROUP_SIZE_M * num_pid_n
+     first_pid_m = (pid // num_pid_in_grp) * GROUP_SIZE_M
+     group_rows = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+
+     pid_m = first_pid_m + ((pid % num_pid_in_grp) % group_rows)  # row-tile index
+     pid_n = (pid % num_pid_in_grp) // group_rows  # centroid-tile index
+
+     row_start = pid_m * BLOCK_M
+     col_start = pid_n * BLOCK_N
+
      rows = row_start + tl.arange(0, BLOCK_M)
-     mask = rows < B
+     cols = col_start + tl.arange(0, BLOCK_N)
 
-     # 2) Load data rows and precomputed x_norm: shape [BLOCK_M, D]
-     row_offsets = rows[:, None] * D + tl.arange(0, D)
-     x = tl.load(data_ptr + row_offsets, mask=mask[:, None], other=0.0)
+     row_mask = rows < B
+     col_mask = cols < C
 
-     # shape: [BLOCK_M]
-     x_norm = tl.load(x_norm_ptr + rows, mask=mask, other=0.0)
+     # load norms
+     x_n = tl.load(x_norm_ptr + rows, mask=row_mask, other=0.0)  # [BM]
+     c_n = tl.load(c_norm_ptr + cols, mask=col_mask, other=0.0)  # [BN]
 
-     # Prepare "best distance" + "best index"
-     best_dist = tl.full([BLOCK_M], 1e38, dtype=tl.float32)
-     best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
+     # pipelined K‑loop, will hold partial dot‑products in registers
+     dot_acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
 
-     # 3) Iterate over the centroids in chunks of BLOCK_N
-     for chunk in range(0, C, BLOCK_N):
-         cids = chunk + tl.arange(0, BLOCK_N)
-         c_mask = cids < C
+     # compute matmul tiled across SMs
+     for k0 in range(0, D, BLOCK_K):
+         k_range = k0 + tl.arange(0, BLOCK_K)
 
-         # Load sub-block of centroids: shape [BLOCK_N, D]
-         c_offsets = cids[:, None] * D + tl.arange(0, D)
-         cvals = tl.load(centroids_ptr + c_offsets, mask=c_mask[:, None], other=0.0).to(x.dtype)
+         # load X slice
+         x_ptrs = x_ptr + rows[:, None] * D + k_range[None, :]
+         xk = tl.load(x_ptrs, mask=row_mask[:, None]).to(tl.float16)
 
-         # Load centroid norms: shape [BLOCK_N]
-         c_sqnorm = tl.load(centroids_sqnorm_ptr + cids, mask=c_mask, other=0.0).to(x.dtype)
+         # load C slice
+         c_ptrs = c_ptr + cols[:, None] * D + k_range[None, :]
+         ck = tl.load(c_ptrs, mask=col_mask[:, None]).to(tl.float16)
 
-         # Compute distance = x_norm + c_sqnorm - 2 * dot(x, c)
-         dots = tl.dot(x, tl.trans(cvals)) # shape [BLOCK_M, BLOCK_N]
-         dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
+         # accumulate
+         dot_acc += tl.dot(xk, tl.trans(ck), out_dtype=tl.float32)
 
-         # Find the argmin along the BLOCK_N dimension
-         local_min_vals, local_min_idx = tl.min(dist_chunk, axis=1, return_indices=True)
+     # finish distance formula
+     dist = tl.fma(dot_acc, -2.0, x_n[:, None] + c_n[None, :])  # [BM, BN]
 
-         improved = local_min_vals < best_dist
-         best_dist = tl.where(improved, local_min_vals, best_dist)
-         best_idx = tl.where(improved, chunk + local_min_idx, best_idx)
+     # local arg‑min (inside this tile)
+     tile_min, tile_idx = tl.min(dist, axis=1, return_indices=True)
 
-     # 4) Write out the best centroid indices
-     tl.store(best_ids_ptr + rows, best_idx, mask=mask)
+     # compete with global best using atomics
+     prev = tl.atomic_min(best_dist_ptr + rows, tile_min, mask=row_mask)
+     improved = tile_min < prev
 
+     # update best_ids
+     tl.store(best_idx_ptr + rows, tl.where(improved, col_start + tile_idx, tl.load(best_idx_ptr + rows)), mask=row_mask)
 
- def chunked_kmeans_kernel(
+
+ def triton_kmeans(
      data_chunk: torch.Tensor,
      data_chunk_norms: torch.Tensor,
      centroids: torch.Tensor,
      centroids_sqnorm: torch.Tensor,
      best_ids: torch.Tensor,
  ):
-     """
-     Launches the Triton kernel to assign each point to its nearest centroid in one pass.
-
-     best_ids: pre-allocated [B] (int32) to store the best centroid ID of each point.
-     """
      B, D = data_chunk.shape
      C = centroids.shape[0]
+     best_dist = torch.full((B,), 1e38, device=data_chunk.device, dtype=torch.float32)
 
      def grid(meta):
-         return (triton.cdiv(B, meta["BLOCK_M"]),)
-
-     _chunked_kmeans_kernel[grid](
-         data_chunk,
-         data_chunk_norms,
-         centroids,
-         centroids_sqnorm,
-         best_ids,
-         B, # num_points
-         C, # num_centroids
-         D, # dimension, or number of features
-     )
+         return (triton.cdiv(B, meta["BLOCK_M"]) * triton.cdiv(C, meta["BLOCK_N"]),)  # 1D grid
+
+     # Without this Triton always tries to launch from device:0 and we get
+     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+     with device_guard(data_chunk):
+         _kmeans_kernel[grid](
+             data_chunk,
+             data_chunk_norms,
+             centroids,
+             centroids_sqnorm,
+             best_dist,
+             best_ids,
+             B,
+             C,
+             D,
+         )
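For anyone exercising the launcher in isolation, it can be driven directly with the tensor layout shown in its signature: row-major chunks plus precomputed squared norms, and a preallocated `best_ids`. A sketch, assuming a CUDA device with Triton installed; agreement with a float32 reference is near-exact but not bitwise, since the kernel accumulates fp16 tiles:

```python
import torch

from fastkmeans.triton_kernels import triton_kmeans

device = torch.device("cuda")
x = torch.randn(4096, 128, device=device)  # data chunk [B, D]
c = torch.randn(256, 128, device=device)   # centroids [C, D]

best_ids = torch.zeros(4096, dtype=torch.int64, device=device)
triton_kmeans(
    data_chunk=x,
    data_chunk_norms=(x**2).sum(dim=1),
    centroids=c,
    centroids_sqnorm=(c**2).sum(dim=1),
    best_ids=best_ids,
)

# cross-check against a float32 reference assignment
ref = torch.cdist(x, c).argmin(dim=1)
print((best_ids == ref).float().mean().item())  # ~1.0, modulo fp16 rounding
```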
fastkmeans-0.5.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fastkmeans
- Version: 0.3.0
+ Version: 0.5.0
  Summary: Add your description here
  Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
  Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
@@ -14,15 +14,16 @@ Dynamic: license-file
 
  # fastkmeans
 
- ![Python Versions](https://img.shields.io/badge/Python-3.8_3.9_3.10_3.11_3.12_3.13-blue)
+ ![Python Versions](https://img.shields.io/badge/Python-3.9_3.10_3.11_3.12_3.13-blue)
  [![Twitter Follow](https://img.shields.io/twitter/follow/bclavie?style=social)](https://twitter.com/bclavie)
+ [![Twitter Follow](https://img.shields.io/twitter/follow/benjamin_warner?style=social)](https://twitter.com/benjamin_warner)
  <!-- [![Downloads](https://static.pepy.tech/badge/fastkmeans/month)](https://pepy.tech/project/fastkmeans) -->
 
  _A fast and efficient k-means implementation for PyTorch, with support for GPU and CPU._
 
  ---
 
- Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted-in anywhere you need "fast-enough" pytorch native k-means clustering. It's compatible with both CPU and GPU, matching or outspeeding `faiss` (except in multi-GPU settings), and it comes without any install woes, relying on just two dependencies you already have installed anyway: `torch` and `numpy`.
+ Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with any PyTorch-supported CPU or GPU, matching or outperforming `faiss` by ~4-5× on a single GPU, and comes without install woes, relying on just two dependencies you already have installed: `torch` and `numpy`.
 
  ### Get started
 
@@ -43,18 +44,22 @@ There's very, very little to this library. It provides a single interface, `Fast
 
  #### Behaviour
 
- Whenever possible, the library will attempt to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see what the arguments are, as they're very straightforward. The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library does adopt an early stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids don't move more than `tol` between iterations. This is unlike faiss', whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
+ Whenever possible, the library attempts to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/main/fastkmeans/kmeans.py#L170) to see what the arguments are, as they are straightforward.
+
+ The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids don't move more than `tol` between iterations. This is unlike faiss's, whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
 
  #### Chunking
 
  The algorithm is implemented with double-chunking logic, where both the data points and the centroids are split into moderately-sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing chunk sizes will speed up computations at the cost of memory usage, and decreasing them will have the reverse effect.
 
+ > Note: See the Triton section below for Triton-specific chunking details.
+
  ### Why `fastkmeans`?
 
  The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, in its Stanford implementation, its PyLate implementation, and the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great on their own merits, have flaws for this particular use:
 
  - `faiss` is highly optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install as there are no "perfect" PyPI wheels, with many segfault issues reported, as the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not installed via conda too. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
- - `fast-pytorch-kmeans` is a great library which provides lightning fast kmeans implementation in PyTorch: in fact, it's faster than this library! However, it relies on highly vectorized operations which are exceedingly memory hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of colbert documents (or produces suboptimal clusters with minibatching).
+ - `fast-pytorch-kmeans` is a great library which provides a lightning-fast k-means implementation in PyTorch. However, it relies on highly vectorized operations which are exceedingly memory hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
  - `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
 
  There are some libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
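A minimal sketch of the two behaviour knobs described above, early stopping via `tol` and faiss-style subsampling via `max_points_per_centroid` (the `d`/`k` values are illustrative):

```python
from fastkmeans import FastKMeans

# Default: stop once centroids shift less than `tol` between iterations.
km = FastKMeans(d=128, k=8192, tol=1e-8)

# faiss-like: run the full iteration budget (tol=-1 disables early stopping)
# and subsample to at most 256 points per centroid before clustering.
km_faiss = FastKMeans(d=128, k=8192, tol=-1, max_points_per_centroid=256)
```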
@@ -65,8 +70,42 @@ There are some libraries (such as NVidia's own implementations), but they again
  - The "chunking" defaults to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
  - The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up; this might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
 
+ ### Triton Kernel
+
+ `fastkmeans`'s Triton kmeans kernel is ~4-5 times faster than single-GPU `faiss` or `fastkmeans`'s PyTorch backend. On a modern GPU (Ampere or newer), the Triton backend is enabled by default.
+
+ While the Triton kernel uses significantly less memory than the PyTorch implementation, increasing the chunk size above 512K can result in slower performance.
+
  ### Speed
 
- Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, 131072 and 262144 clusters, each with w/ cluster_size*100 data points).
+ Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, and 131072 clusters, each with cluster_size*100 data points).
 
  ![fastkmeans benchmark](./benchmark_plots/4090_benchmark.png)
+
+ #### Benchmarking
+
+ To benchmark `fastkmeans` against `faiss` on your own machine, install faiss and PyTorch 2.5 via the `bench_env.yaml` Conda environment:
+
+ ```bash
+ conda env create -f bench_env.yaml
+ conda activate fastkmeans
+ pip install fastkmeans
+ ```
+
+ Then, run the benchmark script:
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 python speedbench.py --do-faiss --do-fastkmeans --do-fastkmeans-triton --do-evals
+ ```
+
+ ### Citation
+
+ If you use fastkmeans and want to/need to cite it in your work, please feel free to use the citation below:
+
+ ```bibtex
+ @misc{fastkmeans2025,
+   author = {Benjamin Clavié and Benjamin Warner},
+   title = {fastkmeans: Accelerated KMeans Clustering in PyTorch and Triton},
+   year = {2025},
+   howpublished = {\url{https://github.com/AnswerDotAI/fastkmeans/}}
+ }
+ ```
fastkmeans-0.5.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ fastkmeans/__init__.py,sha256=wdSM-ig64WXThY0P7AVKSnyAt7trUuSVffIvbzhez7Y,79
+ fastkmeans/kmeans.py,sha256=Z7GEP3oWQx9NCg0poDzTvBrz3T4K8YWhKR8AKFjMEcY,14135
+ fastkmeans/triton_kernels.py,sha256=LYErKjjvwA4lGAoawFs_KM5vDuCyhUDfPRBiTsrVXdU,4155
+ fastkmeans-0.5.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ fastkmeans-0.5.0.dist-info/METADATA,sha256=afRBC3fHKnNETs172OirwzY6lcANthwH2fYhchQ5F8I,8043
+ fastkmeans-0.5.0.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
+ fastkmeans-0.5.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
+ fastkmeans-0.5.0.dist-info/RECORD,,
fastkmeans-0.5.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.1)
+ Generator: setuptools (80.4.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
fastkmeans-0.3.0.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
- fastkmeans/__init__.py,sha256=EKvVznV3rTwe5Yz_tqlF_bJMZsh-5vLSjVvMtCx3Xhk,79
- fastkmeans/kmeans.py,sha256=_FbU_avJPxIEfikhPMPgBibj9ckKej21iWJCm-YNEUM,13778
- fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
- fastkmeans-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- fastkmeans-0.3.0.dist-info/METADATA,sha256=KOyw3KsNPCdF_6spJl2SByB8mVtZ3I6GE1H8N2KMUpM,6791
- fastkmeans-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
- fastkmeans-0.3.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
- fastkmeans-0.3.0.dist-info/RECORD,,