fastkmeans 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastkmeans/__init__.py +1 -1
- fastkmeans/kmeans.py +5 -2
- {fastkmeans-0.2.0.dist-info → fastkmeans-0.3.0.dist-info}/METADATA +1 -1
- fastkmeans-0.3.0.dist-info/RECORD +8 -0
- {fastkmeans-0.2.0.dist-info → fastkmeans-0.3.0.dist-info}/WHEEL +1 -1
- fastkmeans-0.2.0.dist-info/RECORD +0 -8
- {fastkmeans-0.2.0.dist-info → fastkmeans-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {fastkmeans-0.2.0.dist-info → fastkmeans-0.3.0.dist-info}/top_level.txt +0 -0
fastkmeans/__init__.py
CHANGED
fastkmeans/kmeans.py
CHANGED
@@ -3,8 +3,6 @@ import time
|
|
3
3
|
import torch
|
4
4
|
import numpy as np
|
5
5
|
|
6
|
-
from fastkmeans.triton_kernels import chunked_kmeans_kernel
|
7
|
-
|
8
6
|
def _get_device(preset: str | int | torch.device | None = None):
|
9
7
|
if isinstance(preset, torch.device):
|
10
8
|
return preset
|
@@ -53,6 +51,9 @@ def _kmeans_torch_double_chunked(
|
|
53
51
|
Where n_samples_used can be smaller than the original if subsampling occurred.
|
54
52
|
"""
|
55
53
|
|
54
|
+
if use_triton:
|
55
|
+
from fastkmeans.triton_kernels import chunked_kmeans_kernel
|
56
|
+
|
56
57
|
if dtype is None:
|
57
58
|
dtype = torch.float16 if device.type in ['cuda', 'xpu'] else torch.float32
|
58
59
|
|
@@ -274,6 +275,8 @@ class FastKMeans:
|
|
274
275
|
-------
|
275
276
|
labels : np.ndarray of shape (n_samples,), int64
|
276
277
|
"""
|
278
|
+
if self.use_triton:
|
279
|
+
from fastkmeans.triton_kernels import chunked_kmeans_kernel
|
277
280
|
if self.centroids is None:
|
278
281
|
raise RuntimeError("Must call train() or fit() before predict().")
|
279
282
|
|
@@ -0,0 +1,8 @@
|
|
1
|
+
fastkmeans/__init__.py,sha256=EKvVznV3rTwe5Yz_tqlF_bJMZsh-5vLSjVvMtCx3Xhk,79
|
2
|
+
fastkmeans/kmeans.py,sha256=_FbU_avJPxIEfikhPMPgBibj9ckKej21iWJCm-YNEUM,13778
|
3
|
+
fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
|
4
|
+
fastkmeans-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
5
|
+
fastkmeans-0.3.0.dist-info/METADATA,sha256=KOyw3KsNPCdF_6spJl2SByB8mVtZ3I6GE1H8N2KMUpM,6791
|
6
|
+
fastkmeans-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
7
|
+
fastkmeans-0.3.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
|
8
|
+
fastkmeans-0.3.0.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
fastkmeans/__init__.py,sha256=mPLGfhksBBfB1dTkPBPrPg5E1qixixl_Qritc6A10AI,79
|
2
|
-
fastkmeans/kmeans.py,sha256=RLDeaCIbKpCJBzV2XO3JPgLwjrhJ6vtA997Gyzm_GyA,13651
|
3
|
-
fastkmeans/triton_kernels.py,sha256=iN8khhoaQGJ08LQy5iz4VGEXRCtvFKDfCgyGzwVqjgw,3698
|
4
|
-
fastkmeans-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
5
|
-
fastkmeans-0.2.0.dist-info/METADATA,sha256=2hsTZr0t2_CNhCyrECIBk2S4ZB7ytF0YaDqgcgHDvDc,6791
|
6
|
-
fastkmeans-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
7
|
-
fastkmeans-0.2.0.dist-info/top_level.txt,sha256=B3Zd2-kEAH_hN0hFUWgo5lO-TH7ppVol_WQ5ZT1H0js,11
|
8
|
-
fastkmeans-0.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|