fastkmeans 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastkmeans/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from .kmeans import FastKMeans
 
  __all__ = ["FastKMeans"]
- __version__ = "0.2.0"
+ __version__ = "0.4.0"
fastkmeans/kmeans.py CHANGED
@@ -3,26 +3,33 @@ import time
  import torch
  import numpy as np
 
- from fastkmeans.triton_kernels import chunked_kmeans_kernel
+ try:
+     from fastkmeans.triton_kernels import triton_kmeans
+
+     HAS_TRITON = True
+ except ImportError:
+     triton_kmeans = None
+     HAS_TRITON = False
+
 
  def _get_device(preset: str | int | torch.device | None = None):
      if isinstance(preset, torch.device):
          return preset
      if isinstance(preset, str):
          return torch.device(preset)
-     if torch.cuda.is_available(): # cuda currently handles both AMD and NVIDIA GPUs
+     if torch.cuda.is_available():  # cuda currently handles both AMD and NVIDIA GPUs
          return torch.device(f"cuda:{preset if isinstance(preset, int) and preset < torch.cuda.device_count() else 0}")
-     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-         return torch.device('mps')
-     if hasattr(torch, 'xpu') and torch.xpu.is_available():
-         return torch.device(f'xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}')
-     return torch.device('cpu')
+     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         return torch.device("mps")
+     if hasattr(torch, "xpu") and torch.xpu.is_available():
+         return torch.device(f"xpu:{preset if isinstance(preset, int) and preset < torch.xpu.device_count() else 0}")
+     return torch.device("cpu")
 
 
- def _is_bfloat16_supported(device:torch.device):
-     if device.type == 'cuda':
+ def _is_bfloat16_supported(device: torch.device):
+     if device.type == "cuda":
          return torch.cuda.is_bf16_supported()
-     elif device.type == 'xpu' and hasattr(torch.xpu, 'is_bf16_supported'):
+     elif device.type == "xpu" and hasattr(torch.xpu, "is_bf16_supported"):
          return torch.xpu.is_bf16_supported()
      else:
          return False
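
For illustration only, here is a small editorial sketch of how the fallback order in `_get_device` above resolves a preset; the results depend on the hardware present, and `_get_device` is a private helper, so this is not a documented API:

```python
# Editorial sketch: _get_device resolves explicit devices first, then falls back
# cuda -> mps -> xpu -> cpu depending on what is available on the machine.
from fastkmeans.kmeans import _get_device

_get_device("cuda:1")  # strings are passed straight to torch.device
_get_device(1)         # with CUDA available, picks cuda:1 only if that many devices exist, else cuda:0
_get_device(None)      # no preset: first available backend wins (cuda, then mps, then xpu, then cpu)
```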
@@ -53,8 +60,12 @@ def _kmeans_torch_double_chunked(
      Where n_samples_used can be smaller than the original if subsampling occurred.
      """
 
+     if use_triton:
+         if not HAS_TRITON:
+             raise ImportError("Triton is not available. Please install Triton and try again.")
+
      if dtype is None:
-         dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
+         dtype = torch.float16 if device.type in ["cuda", "xpu"] else torch.float32
 
      n_samples_original, n_features = data.shape
      n_samples = n_samples_original
@@ -76,12 +87,12 @@ def _kmeans_torch_double_chunked(
      centroids = data[rand_indices].clone().to(device=device, dtype=dtype)
      prev_centroids = centroids.clone()
 
-     labels = torch.empty(n_samples, dtype=torch.int64, device='cpu') # Keep labels on CPU
+     labels = torch.empty(n_samples, dtype=torch.int64, device="cpu")  # Keep labels on CPU
 
      for iteration in range(max_iters):
          iteration_start_time = time.time()
 
-         centroid_norms = (centroids ** 2).sum(dim=1)
+         centroid_norms = (centroids**2).sum(dim=1)
          cluster_sums = torch.zeros((k, n_features), device=device, dtype=torch.float32)
          cluster_counts = torch.zeros((k,), device=device, dtype=torch.float32)
 
@@ -95,7 +106,7 @@ def _kmeans_torch_double_chunked(
              best_ids = torch.zeros((batch_size,), device=device, dtype=torch.int64)
 
              if use_triton:
-                 chunked_kmeans_kernel(
+                 triton_kmeans(
                      data_chunk=data_chunk,
                      data_chunk_norms=data_chunk_norms,
                      centroids=centroids,
@@ -103,7 +114,7 @@
                      best_ids=best_ids,
                  )
              else:
-                 best_dist = torch.full((batch_size,), float('inf'), device=device, dtype=dtype)
+                 best_dist = torch.full((batch_size,), float("inf"), device=device, dtype=dtype)
                  c_start = 0
                  while c_start < k:
                      c_end = min(c_start + chunk_size_centroids, k)
@@ -116,23 +127,23 @@ def _kmeans_torch_double_chunked(
                      local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                      improved_mask = local_min_vals < best_dist
                      best_dist[improved_mask] = local_min_vals[improved_mask]
-                     best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                     best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
 
                      c_start = c_end
 
              cluster_sums.index_add_(0, best_ids, data_chunk.float())
              cluster_counts.index_add_(0, best_ids, torch.ones_like(best_ids, device=device, dtype=torch.float32))
 
-             labels[start_idx:end_idx] = best_ids.to('cpu', non_blocking=True)
+             labels[start_idx:end_idx] = best_ids.to("cpu", non_blocking=True)
              start_idx = end_idx
 
          new_centroids = torch.zeros_like(centroids, device=device, dtype=dtype)
-         non_empty = (cluster_counts > 0)
+         non_empty = cluster_counts > 0
          new_centroids[non_empty] = (cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1)).to(dtype=dtype)
 
          empty_ids = (~non_empty).nonzero(as_tuple=True)[0]
          if len(empty_ids) > 0:
-             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device='cpu')
+             reinit_indices = torch.randint(0, n_samples, (len(empty_ids),), device="cpu")
              random_data = data[reinit_indices].to(device=device, dtype=dtype, non_blocking=True)
              new_centroids[empty_ids] = random_data
 
@@ -143,14 +154,16 @@ def _kmeans_torch_double_chunked(
 
          iteration_time = time.time() - iteration_start_time
          if verbose:
-             print(f"Iteration {iteration+1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}")
+             print(
+                 f"Iteration {iteration + 1}/{max_iters} took {iteration_time:.4f}s, total time: {time.time() - iteration_start_time + iteration_time:.4f}s, shift: {shift:.6f}"
+             )
 
          if shift < tol:
              if verbose:
-                 print(f"Converged after {iteration+1} iterations (shift: {shift:.6f} < tol: {tol})")
+                 print(f"Converged after {iteration + 1} iterations (shift: {shift:.6f} < tol: {tol})")
              break
 
-     centroids_cpu = centroids.to('cpu', dtype=torch.float32)
+     centroids_cpu = centroids.to("cpu", dtype=torch.float32)
      return centroids_cpu, labels
 
 
@@ -176,9 +189,9 @@ class FastKMeans:
      max_points_per_centroid : int, optional, default=1_000_000_000
          If n_samples > k * max_points_per_centroid, the data will be subsampled to exactly
          k * max_points_per_centroid points before clustering.
-     chunk_size_data : int, default=50_000
+     chunk_size_data : int, default=51_200
          Chunk size along the data dimension for assignment/update steps.
-     chunk_size_centroids : int, default=10_000
+     chunk_size_centroids : int, default=10_240
          Chunk size along the centroid dimension for assignment/update steps.
      use_triton : bool | None, default=None
          Use the fast Triton backend for the assignment/update steps.
@@ -193,14 +206,14 @@ class FastKMeans:
          tol: float = 1e-8,
          gpu: bool = True,
          seed: int = 0,
-         max_points_per_centroid: int = 256,
-         chunk_size_data: int = 50_000,
-         chunk_size_centroids: int = 10_000,
+         max_points_per_centroid: int = 1_000_000_000,
+         chunk_size_data: int = 51_200,
+         chunk_size_centroids: int = 10_240,
          device: str | int | torch.device | None = None,
          dtype: torch.dtype = None,
          pin_gpu_memory: bool = True,
          verbose: bool = False,
-         nredo: int = 1, # for compatibility only
+         nredo: int = 1,  # for compatibility only
          use_triton: bool | None = None,
      ):
          self.d = d
@@ -216,8 +229,11 @@ class FastKMeans:
          self.dtype = dtype
          self.pin_gpu_memory = pin_gpu_memory
          self.verbose = verbose
-         if use_triton is not False:
-             use_triton = _is_bfloat16_supported(self.device) # assume triton is supported if GPU supports bfloat16
+         if use_triton is None:
+             # assume triton kernel is supported if GPU supports bfloat16
+             use_triton = HAS_TRITON and _is_bfloat16_supported(self.device)
+         if use_triton and not HAS_TRITON:
+             raise ValueError("Triton is not available. Please install Triton and try again.")
          self.use_triton = use_triton
          if nredo != 1:
              raise ValueError("nredo must be 1, redos not currently supported")
@@ -236,10 +252,10 @@ class FastKMeans:
 
          # Move data to PyTorch CPU Tensor
          data_torch = torch.from_numpy(data)
-         data_norms_torch = (data_torch ** 2).sum(dim=1)
+         data_norms_torch = (data_torch**2).sum(dim=1)
 
          device = _get_device(self.device)
-         if device == 'cuda' and self.pin_gpu_memory:
+         if device == "cuda" and self.pin_gpu_memory:
              data_torch = data_torch.pin_memory()
              data_norms_torch = data_norms_torch.pin_memory()
 
@@ -278,27 +294,29 @@ class FastKMeans:
              raise RuntimeError("Must call train() or fit() before predict().")
 
          data_torch = torch.from_numpy(data)
-         data_norms_torch = (data_torch ** 2).sum(dim=1)
+         data_norms_torch = (data_torch**2).sum(dim=1)
 
          # We'll do a chunked assignment pass, similar to the main loop, but no centroid updates
          centroids_torch = torch.from_numpy(self.centroids)
          centroids_torch = centroids_torch.to(device=self.device, dtype=torch.float32)
-         centroid_norms = (centroids_torch ** 2).sum(dim=1)
+         centroid_norms = (centroids_torch**2).sum(dim=1)
 
          n_samples = data_torch.shape[0]
-         labels = torch.empty(n_samples, dtype=torch.long, device='cpu')
+         labels = torch.empty(n_samples, dtype=torch.long, device="cpu")
 
          start_idx = 0
          while start_idx < n_samples:
              end_idx = min(start_idx + self.chunk_size_data, n_samples)
 
              data_chunk = data_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
-             data_chunk_norms = data_norms_torch[start_idx:end_idx].to(device=self.device, dtype=torch.float32, non_blocking=True)
+             data_chunk_norms = data_norms_torch[start_idx:end_idx].to(
+                 device=self.device, dtype=torch.float32, non_blocking=True
+             )
              batch_size = data_chunk.size(0)
              best_ids = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
 
              if self.use_triton:
-                 chunked_kmeans_kernel(
+                 triton_kmeans(
                      data_chunk,
                      data_chunk_norms,
                      centroids_torch,
@@ -306,7 +324,7 @@ class FastKMeans:
                      best_ids,
                  )
              else:
-                 best_dist = torch.full((batch_size,), float('inf'), device=self.device, dtype=torch.float32)
+                 best_dist = torch.full((batch_size,), float("inf"), device=self.device, dtype=torch.float32)
                  c_start = 0
                  k = centroids_torch.shape[0]
                  while c_start < k:
@@ -320,10 +338,10 @@ class FastKMeans:
                      local_min_vals, local_min_ids = torch.min(dist_chunk, dim=1)
                      improved_mask = local_min_vals < best_dist
                      best_dist[improved_mask] = local_min_vals[improved_mask]
-                     best_ids[improved_mask] = (c_start + local_min_ids[improved_mask])
+                     best_ids[improved_mask] = c_start + local_min_ids[improved_mask]
                      c_start = c_end
 
-             labels[start_idx:end_idx] = best_ids.to('cpu')
+             labels[start_idx:end_idx] = best_ids.to("cpu")
              start_idx = end_idx
 
          return labels.numpy()
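
For orientation, here is a minimal end-to-end usage sketch of the `FastKMeans` interface diffed above. The keywords `d`, `tol`, and `use_triton` appear in `__init__` above; the cluster-count keyword `k` and the exact call signatures of `train()`/`predict()` are assumed from the docstring and README, so treat this as illustrative rather than the package's documented example:

```python
# Editorial sketch: end-to-end use of the FastKMeans interface shown in the diff above.
# The keyword `k` (number of clusters) is assumed; the other kwargs appear in __init__.
import numpy as np
from fastkmeans import FastKMeans

data = np.random.rand(100_000, 128).astype(np.float32)  # [n_samples, d]

km = FastKMeans(
    d=128,            # vector dimensionality (stored as self.d in __init__)
    k=1_024,          # number of clusters (assumed keyword name)
    tol=1e-8,         # early stopping threshold on centroid shift; -1 disables early stopping
    use_triton=None,  # None = auto-enable the Triton backend when available
)
km.train(data)              # fit; centroids end up in km.centroids as float32 on CPU
labels = km.predict(data)   # chunked nearest-centroid assignment -> int64 array of shape [n_samples]
```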
fastkmeans/triton_kernels.py CHANGED
@@ -1,105 +1,132 @@
+ from contextlib import contextmanager
+
  import torch
  import triton
  import triton.language as tl
 
 
+ @contextmanager
+ def device_guard(tensor: torch.Tensor):
+     """Context manager to ensure that the Triton kernel launches on the correct device."""
+     if tensor.device.type == "cuda":  # NVIDIA or AMD/ROCm
+         with torch.cuda.device_of(tensor):
+             yield
+     elif tensor.device.type == "xpu":  # Intel GPUs
+         with torch.xpu.device_of(tensor):
+             yield
+     else:  # CPU or other back-ends
+         yield
+
+
  @triton.heuristics(
      {
-         "BLOCK_M": lambda a: 128 if a["D"] <= 128 else (64 if a["D"] <= 512 else 32),
-         "BLOCK_N": lambda a: 64 if a["D"] <= 128 else 32,
-         "num_warps": lambda a: 4 if a["D"] <= 64 else 8,
-         "num_stages": lambda a: 2 if a["D"] < 64 else 1,
+         "BLOCK_M": lambda x: 128 if x["D"] <= 384 else 64,
+         "BLOCK_N": lambda x: 128 if x["D"] <= 384 else 64,
+         "BLOCK_K": lambda x: 16 if x["D"] <= 32 or x["D"] > 384 else 32,
+         "GROUP_SIZE_M": lambda x: 8 if x["D"] <= 32 else 16,
+         "num_warps": lambda x: 4,
      }
  )
  @triton.jit
- def _chunked_kmeans_kernel(
-     data_ptr,  # [B, D], row-major
-     x_norm_ptr,  # [B], precomputed L2 norms of data
-     centroids_ptr,  # [C, D], row-major
-     centroids_sqnorm_ptr,  # [C], precomputed L2 norms of centroids
-     best_ids_ptr,  # [B], int32 (to store best centroid indices)
-     B,  # number of data points
-     C,  # number of centroids
-     D: tl.constexpr,  # dimension, or number of features
+ def _kmeans_kernel(
+     x_ptr,
+     x_norm_ptr,
+     c_ptr,
+     c_norm_ptr,
+     best_dist_ptr,
+     best_idx_ptr,
+     B,
+     C,
+     D: tl.constexpr,
      BLOCK_M: tl.constexpr,
      BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
  ):
-     """
-     Each Triton block processes BLOCK_M rows of data. The kernel:
-       1) loads those rows (and their precomputed norms from x_norm_ptr),
-       2) loops over all centroids in chunks of BLOCK_N,
-       3) computes distances, finds the best centroid,
-       4) writes out the best centroid index for each data point.
-     """
-     # 1) Identify which data rows this block handles
-     block_id = tl.program_id(axis=0)
-     row_start = block_id * BLOCK_M
+     # Map flat CTA id to (pid_m, pid_n) in "grouped" launch order
+     pid = tl.program_id(axis=0)
+
+     num_pid_m = tl.cdiv(B, BLOCK_M)  # row-tiles
+     num_pid_n = tl.cdiv(C, BLOCK_N)  # centroid-tiles
+
+     # Super-group into GROUP_SIZE_M blocks to minimize loading from global memory
+     num_pid_in_grp = GROUP_SIZE_M * num_pid_n
+     first_pid_m = (pid // num_pid_in_grp) * GROUP_SIZE_M
+     group_rows = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+
+     pid_m = first_pid_m + ((pid % num_pid_in_grp) % group_rows)  # row-tile index
+     pid_n = (pid % num_pid_in_grp) // group_rows  # centroid-tile index
+
+     row_start = pid_m * BLOCK_M
+     col_start = pid_n * BLOCK_N
+
      rows = row_start + tl.arange(0, BLOCK_M)
-     mask = rows < B
+     cols = col_start + tl.arange(0, BLOCK_N)
 
-     # 2) Load data rows and precomputed x_norm: shape [BLOCK_M, D]
-     row_offsets = rows[:, None] * D + tl.arange(0, D)
-     x = tl.load(data_ptr + row_offsets, mask=mask[:, None], other=0.0)
+     row_mask = rows < B
+     col_mask = cols < C
 
-     # shape: [BLOCK_M]
-     x_norm = tl.load(x_norm_ptr + rows, mask=mask, other=0.0)
+     # load norms
+     x_n = tl.load(x_norm_ptr + rows, mask=row_mask, other=0.0)  # [BM]
+     c_n = tl.load(c_norm_ptr + cols, mask=col_mask, other=0.0)  # [BN]
 
-     # Prepare "best distance" + "best index"
-     best_dist = tl.full([BLOCK_M], 1e38, dtype=tl.float32)
-     best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
+     # pipelined K-loop, will hold partial dot-products in registers
+     dot_acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
 
-     # 3) Iterate over the centroids in chunks of BLOCK_N
-     for chunk in range(0, C, BLOCK_N):
-         cids = chunk + tl.arange(0, BLOCK_N)
-         c_mask = cids < C
+     # compute matmul tiled across SMs
+     for k0 in range(0, D, BLOCK_K):
+         k_range = k0 + tl.arange(0, BLOCK_K)
 
-         # Load sub-block of centroids: shape [BLOCK_N, D]
-         c_offsets = cids[:, None] * D + tl.arange(0, D)
-         cvals = tl.load(centroids_ptr + c_offsets, mask=c_mask[:, None], other=0.0).to(x.dtype)
+         # load X slice
+         x_ptrs = x_ptr + rows[:, None] * D + k_range[None, :]
+         xk = tl.load(x_ptrs, mask=row_mask[:, None]).to(tl.float16)
 
-         # Load centroid norms: shape [BLOCK_N]
-         c_sqnorm = tl.load(centroids_sqnorm_ptr + cids, mask=c_mask, other=0.0).to(x.dtype)
+         # load C slice
+         c_ptrs = c_ptr + cols[:, None] * D + k_range[None, :]
+         ck = tl.load(c_ptrs, mask=col_mask[:, None]).to(tl.float16)
 
-         # Compute distance = x_norm + c_sqnorm - 2 * dot(x, c)
-         dots = tl.dot(x, tl.trans(cvals))  # shape [BLOCK_M, BLOCK_N]
-         dist_chunk = tl.fma(dots, -2.0, x_norm[:, None] + c_sqnorm[None, :])
+         # accumulate
+         dot_acc += tl.dot(xk, tl.trans(ck), out_dtype=tl.float32)
 
-         # Find the argmin along the BLOCK_N dimension
-         local_min_vals, local_min_idx = tl.min(dist_chunk, axis=1, return_indices=True)
+     # finish distance formula
+     dist = tl.fma(dot_acc, -2.0, x_n[:, None] + c_n[None, :])  # [BM, BN]
 
-         improved = local_min_vals < best_dist
-         best_dist = tl.where(improved, local_min_vals, best_dist)
-         best_idx = tl.where(improved, chunk + local_min_idx, best_idx)
+     # local arg-min (inside this tile)
+     tile_min, tile_idx = tl.min(dist, axis=1, return_indices=True)
 
-     # 4) Write out the best centroid indices
-     tl.store(best_ids_ptr + rows, best_idx, mask=mask)
+     # compete with global best using atomics
+     prev = tl.atomic_min(best_dist_ptr + rows, tile_min, mask=row_mask)
+     improved = tile_min < prev
 
+     # update best_ids
+     tl.store(best_idx_ptr + rows, tl.where(improved, col_start + tile_idx, tl.load(best_idx_ptr + rows)), mask=row_mask)
 
- def chunked_kmeans_kernel(
+
+ def triton_kmeans(
      data_chunk: torch.Tensor,
      data_chunk_norms: torch.Tensor,
      centroids: torch.Tensor,
      centroids_sqnorm: torch.Tensor,
      best_ids: torch.Tensor,
  ):
-     """
-     Launches the Triton kernel to assign each point to its nearest centroid in one pass.
-
-     best_ids: pre-allocated [B] (int32) to store the best centroid ID of each point.
-     """
      B, D = data_chunk.shape
      C = centroids.shape[0]
+     best_dist = torch.full((B,), 1e38, device=data_chunk.device, dtype=torch.float32)
 
      def grid(meta):
-         return (triton.cdiv(B, meta["BLOCK_M"]),)
-
-     _chunked_kmeans_kernel[grid](
-         data_chunk,
-         data_chunk_norms,
-         centroids,
-         centroids_sqnorm,
-         best_ids,
-         B,  # num_points
-         C,  # num_centroids
-         D,  # dimension, or number of features
-     )
+         return (triton.cdiv(B, meta["BLOCK_M"]) * triton.cdiv(C, meta["BLOCK_N"]),)  # 1D grid
+
+     # Without this Triton always tries to launch from device:0 and we get
+     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+     with device_guard(data_chunk):
+         _kmeans_kernel[grid](
+             data_chunk,
+             data_chunk_norms,
+             centroids,
+             centroids_sqnorm,
+             best_dist,
+             best_ids,
+             B,
+             C,
+             D,
+         )
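
As a sanity reference for what `_kmeans_kernel` computes, here is a dense PyTorch equivalent of the assignment step (squared-L2 argmin via ||x||^2 + ||c||^2 - 2·x·cᵀ), mirroring the non-Triton fallback path in `kmeans.py`. This is an editorial sketch, not part of the package:

```python
# Editorial sketch: dense PyTorch reference for the nearest-centroid assignment
# that triton_kmeans performs tile by tile with an atomic min over best_dist.
import torch

def assign_reference(data_chunk: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    x_norms = (data_chunk**2).sum(dim=1)  # [B]
    c_norms = (centroids**2).sum(dim=1)   # [C]
    # squared L2 distance: ||x||^2 + ||c||^2 - 2 * x @ c^T
    dist = x_norms[:, None] + c_norms[None, :] - 2.0 * (data_chunk @ centroids.T)  # [B, C]
    return dist.argmin(dim=1)  # [B] index of the closest centroid (tie-breaking may differ from the kernel)
```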
fastkmeans-0.2.0.dist-info/METADATA → fastkmeans-0.4.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fastkmeans
- Version: 0.2.0
+ Version: 0.4.0
  Summary: Add your description here
  Author-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
  Maintainer-email: Ben Clavié <bc@answer.ai>, Benjamin Warner <bw@answer.ai>
@@ -22,7 +22,7 @@ _A fast and efficient k-means implementation for PyTorch, with support for GPU a
 
  ---
 
- Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted-in anywhere you need "fast-enough" pytorch native k-means clustering. It's compatible with both CPU and GPU, matching or outspeeding `faiss` (except in multi-GPU settings), and it comes without any install woes, relying on just two dependencies you already have installed anyway: `torch` and `numpy`.
+ Welcome to `fastkmeans`! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with any CPU or GPU that PyTorch supports, matching or outperforming `faiss` by ~4-5× on a single GPU, and it comes without install woes, relying on just two dependencies you already have installed: `torch` and `numpy`.
 
  ### Get started
 
@@ -43,18 +43,22 @@ There's very, very little to this library. It provides a single interface, `Fast
 
  #### Behaviour
 
- Whenever possible, the library will attempt to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see what the arguments are, as they're very straightforward. The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library does adopt an early stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids don't move more than `tol` between iterations. This is unlike faiss', whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
+ Whenever possible, the library attempts to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the [API docstring](https://github.com/AnswerDotAI/fastkmeans/blob/main/fastkmeans/kmeans.py#L170) to see what the arguments are, as they are straightforward.
+
+ The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early stopping mechanism based on a `tol` parameter, which stops the algorithm when the centroids don't move by more than `tol` between iterations. This is unlike faiss, whose default behaviour is to run for a fixed number of iterations no matter what -- you can restore this behaviour by setting `tol` to -1.
 
  #### Chunking
 
  The algorithm is implemented with double-chunking logic, where both the data points and the centroids are split into moderately-sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). You can check out the available arguments [here](https://github.com/AnswerDotAI/fastkmeans/blob/17c2a1b4cabc84c7bb0cb392fd2d2e3ec4d1b825/fastkmeans/kmeans.py#L132) to see how to tweak these. As a rule of thumb, increasing the chunk sizes will speed up computations at the cost of memory usage, and decreasing them will have the reverse effect.
 
+ > Note: See the Triton section below for Triton-specific chunking details.
+
  ### Why `fastkmeans`?
 
  The main motivation behind `fastkmeans` is having a considerably easier way to package late-interaction models, such as ColBERT, in its Stanford implementation, its PyLate implementation, and the RAGatouille high-level API. The use of clustering for ColBERT is perhaps somewhat peculiar, as it relies on **large numbers of clusters** for relatively few data points (~100 per cluster centre). This has been a major problem in getting ColBERT to be more usable, as the existing alternatives, while great in their own right, have flaws for this particular use:
 
  - `faiss` is highly optimized and is the original library used by the ColBERT authors and most implementations nowadays. However, it is rather difficult to install as there are no "perfect" PyPI wheels, with many segfault issues reported, as the official install is only supported via conda or from source. It can also be finicky, causing issues with PyTorch if not installed via conda too. Finally, and this is a major problem for maintaining libraries such as RAGatouille: it requires different packages and different install methods for its `cpu` and `gpu` variants, meaning additional friction for users and the inability to provide a nice default.
- - `fast-pytorch-kmeans` is a great library which provides lightning fast kmeans implementation in PyTorch: in fact, it's faster than this library! However, it relies on highly vectorized operations which are exceedingly memory hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of colbert documents (or produces suboptimal clusters with minibatching).
+ - `fast-pytorch-kmeans` is a great library which provides a lightning-fast k-means implementation in PyTorch. However, it relies on highly vectorized operations which are exceedingly memory hungry, and consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
  - `scikit-learn`, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
 
  There are some libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging, and/or are less flexible in terms of hardware.
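
To make the Behaviour section above concrete, here is a short editorial sketch of restoring faiss-like defaults with the documented keyword arguments `tol` and `max_points_per_centroid`; the cluster-count keyword `k` is assumed, and this is not an example from the package itself:

```python
# Editorial sketch: reproduce faiss-style behaviour with the documented kwargs.
from fastkmeans import FastKMeans

km = FastKMeans(
    d=128,
    k=32_768,                     # assumed keyword name for the number of clusters
    tol=-1,                       # disable early stopping: run the full iteration budget, like faiss
    max_points_per_centroid=256,  # re-enable faiss-style subsampling (the 0.4.0 default of 1_000_000_000 effectively disables it)
)
```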
@@ -65,8 +69,30 @@ There are some libraries (such as NVidia's own implementations), but they again
  - The "chunking" defaults to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
  - The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up; this might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust `faiss` available on their machine anyway.
 
+ ### Triton Kernel
+
+ `fastkmeans`'s Triton k-means kernel is ~4-5 times faster than single-GPU `faiss` or `fastkmeans`'s PyTorch backend. On a modern GPU (Ampere or newer), the Triton backend is enabled by default.
+
+ While the Triton kernel uses significantly less memory than the PyTorch implementation, increasing the chunk size above 512K can result in slower performance.
+
  ### Speed
 
- Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, 131072 and 262144 clusters, each with w/ cluster_size*100 data points).
+ Below is `fastkmeans` benchmarked against `faiss` on a single RTX 4090 GPU, with 128-dimensional data points at various data scales that will commonly be used in ColBERT-style indexing (8192, 16384, 32768, 65536, and 131072 clusters, each with cluster_size*100 data points).
 
  ![fastkmeans benchmark](./benchmark_plots/4090_benchmark.png)
+
+ #### Benchmarking
+
+ To benchmark `fastkmeans` against `faiss` on your own machine, install faiss and PyTorch 2.5 via the `bench_env.yaml` Conda environment:
+
+ ```bash
+ conda env create -f bench_env.yaml
+ conda activate fastkmeans
+ pip install fastkmeans
+ ```
+
+ Then, run the benchmark script:
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 python speedbench.py --do-faiss --do-fastkmeans --do-fastkmeans-triton --do-evals
+ ```
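
Following the Chunking and Triton notes above, here is a small editorial sketch of overriding the chunk sizes via the documented keyword arguments; the values are illustrative choices, not recommendations from the package, and `k` is again an assumed keyword name:

```python
# Editorial sketch: trade memory for speed by tuning the documented chunk-size kwargs.
from fastkmeans import FastKMeans

km = FastKMeans(
    d=128,
    k=262_144,                    # assumed keyword name for the number of clusters
    chunk_size_data=262_144,      # rows of data scored per pass; larger = faster but more memory
    chunk_size_centroids=10_240,  # centroids scored per inner pass (used by the PyTorch backend)
    use_triton=False,             # force the pure-PyTorch backend even if Triton is installed
)
```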
fastkmeans-0.4.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ fastkmeans/__init__.py,sha256=NcMnvnLDRqvTrAXfvVJB0M845Fe8imRzI90dAOlc3MY,79
+ fastkmeans/kmeans.py,sha256=l8W79LZszDiDSS8H2jFke8Wp_blFQdslCkr4Oo2CKik,14099
+ fastkmeans/triton_kernels.py,sha256=vwKUApwKsGWAhN9gMh9rWzandWoEt6rV187xLAW2t3U,4119
+ fastkmeans-0.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ fastkmeans-0.4.0.dist-info/METADATA,sha256=wKNRR_pxOUxTM4cAH2jCLnt7YgUZYYAc1kRG-4YZ8eE,7552
+ fastkmeans-0.4.0.dist-info/WHEEL,sha256=GHB6lJx2juba1wDgXDNlMTyM13ckjBMKf-OnwgKOCtA,91
+ fastkmeans-0.4.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
+ fastkmeans-0.4.0.dist-info/RECORD,,
fastkmeans-0.2.0.dist-info/WHEEL → fastkmeans-0.4.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (80.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
fastkmeans-0.2.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- fastkmeans/__init__.py,sha256=mPLGfhksBBfB1dTkPBPrPg5E1qixixl_Qritc6A10AI,79
- fastkmeans/kmeans.py,sha256=RLDeaCIbKpCJBzV2XO3JPgLwjrhJ6vtA997Gyzm_GyA,13651
- fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
- fastkmeans-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- fastkmeans-0.2.0.dist-info/METADATA,sha256=2hsTZr0t2_CNhCyrECIBk2S4ZB7ytF0YaDqgcgHDvDc,6791
- fastkmeans-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- fastkmeans-0.2.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
- fastkmeans-0.2.0.dist-info/RECORD,,