mlx-mol-cluster 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Tony E. Lin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx-mol-cluster
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary:
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Author: Tony Eight Lin
|
|
8
|
+
Requires-Python: >=3.11
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
14
|
+
Requires-Dist: mlx (>=0.31.0,<0.32.0)
|
|
15
|
+
Requires-Dist: numpy (>=2.4.2,<3.0.0)
|
|
16
|
+
Requires-Dist: pandas (>=3.0.1,<4.0.0)
|
|
17
|
+
Requires-Dist: rdkit (>=2025.9.5,<2026.0.0)
|
|
18
|
+
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
|
|
21
|
+
[](https://github.com/tlint101/MLXMolCluster)
|
|
22
|
+
[](https://github.com/ml-explore/mlx)
|
|
23
|
+
|
|
24
|
+
# MLXMolCluster
|
|
25
|
+
|
|
26
|
+
Leverage Apple Silicon to cluster molecules using MLX.
|
|
27
|
+
|
|
28
|
+
At the time of writing, the project contains two clustering methods:
|
|
29
|
+
- Butina
|
|
30
|
+
- KMeans
|
|
31
|
+
|
|
32
|
+
Additional clustering methods will be added over time.
|
|
33
|
+
|
|
34
|
+
Examples have been written and can be found [here](tutorial).
|
|
35
|
+
|
|
36
|
+
## Installation
|
|
37
|
+
Install directly from GitHub with pip:
|
|
38
|
+
```bash
|
|
39
|
+
pip install git+https://github.com/tlint101/MLXMolCluster.git
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## Example
|
|
43
|
+
The following is an example of clustering molecules using Butina on MLX.
|
|
44
|
+
```python
|
|
45
|
+
# generate molecular fingerprints
|
|
46
|
+
fp_gen = rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=1024)
|
|
47
|
+
rdkit_fps = [fp_gen.GetFingerprint(mol) for mol in mol_list]
|
|
48
|
+
|
|
49
|
+
# convert to mlx arrays
|
|
50
|
+
mlx_fp = fp_to_mlx(rdkit_fps)
|
|
51
|
+
|
|
52
|
+
# Butina cluster
|
|
53
|
+
butina_mlx = butina(mlx_fp)
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
A speed comparison can be seen in the [tutorial section](tutorial). The runs were performed on an M2 Pro chip
|
|
57
|
+
(10 CPU, 16 GPU)
|
|
58
|
+

|
|
59
|
+
|
|
60
|
+
**NOTE:** The figure can be misleading. The figure shows the clustering speed of already generated molecular
|
|
61
|
+
fingerprints. The main bottleneck of clustering remains the generation of molecular fingerprints. This is done on the
|
|
62
|
+
CPU before the fingerprints are transferred to the GPU. Depending on the number of molecules, this can be time-intensive.
|
|
63
|
+
|
|
64
|
+
### Additional
|
|
65
|
+
Collaborations are welcome!
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
[](https://github.com/tlint101/MLXMolCluster)
|
|
2
|
+
[](https://github.com/ml-explore/mlx)
|
|
3
|
+
|
|
4
|
+
# MLXMolCluster
|
|
5
|
+
|
|
6
|
+
Leverage Apple Silicon to cluster molecules using MLX.
|
|
7
|
+
|
|
8
|
+
At the time of writing, the project contains two clustering methods:
|
|
9
|
+
- Butina
|
|
10
|
+
- KMeans
|
|
11
|
+
|
|
12
|
+
Additional clustering methods will be added over time.
|
|
13
|
+
|
|
14
|
+
Examples have been written and can be found [here](tutorial).
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
Install directly from GitHub with pip:
|
|
18
|
+
```bash
|
|
19
|
+
pip install git+https://github.com/tlint101/MLXMolCluster.git
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Example
|
|
23
|
+
The following is an example of clustering molecules using Butina on MLX.
|
|
24
|
+
```python
|
|
25
|
+
# generate molecular fingerprints
|
|
26
|
+
fp_gen = rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=1024)
|
|
27
|
+
rdkit_fps = [fp_gen.GetFingerprint(mol) for mol in mol_list]
|
|
28
|
+
|
|
29
|
+
# convert to mlx arrays
|
|
30
|
+
mlx_fp = fp_to_mlx(rdkit_fps)
|
|
31
|
+
|
|
32
|
+
# Butina cluster
|
|
33
|
+
butina_mlx = butina(mlx_fp)
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
A speed comparison can be seen in the [tutorial section](tutorial). The runs were performed on an M2 Pro chip
|
|
37
|
+
(10 CPU, 16 GPU)
|
|
38
|
+

|
|
39
|
+
|
|
40
|
+
**NOTE:** The figure can be misleading. The figure shows the clustering speed of already generated molecular
|
|
41
|
+
fingerprints. The main bottleneck of clustering remains the generation of molecular fingerprints. This is done on the
|
|
42
|
+
CPU before the fingerprints are transferred to the GPU. Depending on the number of molecules, this can be time-intensive.
|
|
43
|
+
|
|
44
|
+
### Additional
|
|
45
|
+
Collaborations are welcome!
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .cluster import *
|
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import mlx.core as mx
|
|
3
|
+
from rdkit.ML.Cluster import Butina
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def fp_to_mlx(fp: list) -> mx.array:
    """
    Turn a list of fingerprints into a float32 MLX array.
    :param fp: list
        A list of molecular fingerprints calculated using RDKit.
    :return: mx.array
        The fingerprints, routed through NumPy and cast to mx.float32.
    """
    # NumPy builds the dense array first; MLX then takes it over as float32
    return mx.array(np.array(fp)).astype(mx.float32)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_tanimoto(fps: Optional[mx.array] = None, chunk_size: int = 5000, matrix: bool = False):
    """
    Calculate pairwise Tanimoto distances (1 - Tanimoto similarity) between fingerprints.

    Note that, despite the name, the values returned are distances, not similarities.

    :param fps: Optional[mx.array]
        An mx.array of molecular fingerprints, one row per molecule. Required; the ``None`` default is kept only
        for backward compatibility of the signature. (Presumably 0/1 bit values, as the bit counts are obtained by
        summing each row - confirm against callers.)
    :param chunk_size: int
        The number of fingerprint rows compared against the full set per step, to keep under GPU buffer limits.
    :param matrix: bool
        If true, return the full (n, n) distance matrix. Otherwise return the strict lower triangle flattened row
        by row (row 1: d[1,0]; row 2: d[2,0], d[2,1]; ...).
    :return: np.ndarray
        The distances as a NumPy array, either 2-D (matrix) or 1-D (flattened lower triangle).
    :raises ValueError:
        If fps is None or chunk_size is smaller than 1.
    """
    if fps is None:
        raise ValueError("fps must be an mx.array of fingerprints, got None.")
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer.")

    n = fps.shape[0]
    if n == 0:
        # nothing to compare; np.concatenate would fail on an empty list
        return np.empty((0, 0), dtype=np.float32) if matrix else np.empty(0, dtype=np.float32)

    # number of set bits per fingerprint, shape: (n, 1)
    bits_set = mx.sum(fps, axis=1, keepdims=True)

    results = []

    # process by chunks
    for i in range(0, n, chunk_size):
        end_i = min(i + chunk_size, n)
        chunk_fps = fps[i:end_i]  # shape: (chunk, bits)
        chunk_bits = bits_set[i:end_i]  # shape: (chunk, 1)

        # compute intersection
        intersections = mx.matmul(chunk_fps, fps.T)

        # Tanimoto calc; the small epsilon guards against division by zero for all-zero fingerprints
        union = chunk_bits + bits_set.T - intersections
        tanimoto_sim = intersections / (union + 1e-7)

        # move the whole chunk to CPU/NumPy once instead of once per row
        dist_chunk = np.array(1.0 - tanimoto_sim)

        if matrix:
            results.append(dist_chunk)
        else:
            # only keep the lower triangle: global row r keeps columns [0, r)
            results.extend(dist_chunk[row_idx - i, :row_idx] for row_idx in range(i, end_i))

    return np.concatenate(results)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def butina(fingerprints: mx.array = None, cutoff: float = 0.2, chunk_size: int = 5000) -> list:
    """
    Butina-cluster fingerprints using Tanimoto distances computed with MLX.
    :param fingerprints: mx.array
        Molecular fingerprints as an mx.array (e.g. the output of fp_to_mlx).
    :param cutoff: float
        Distance threshold handed to RDKit's Butina.ClusterData.
    :param chunk_size: int
        Forwarded to get_tanimoto to keep each step under GPU buffer limits.
    :return: list
        The clusters returned by RDKit, ordered from largest to smallest.
    """
    # flattened lower-triangle distances, computed on MLX
    flat_distances = get_tanimoto(fingerprints, chunk_size=chunk_size, matrix=False)
    n_points = len(fingerprints)

    # RDKit does the actual clustering on the precomputed distances
    groups = list(Butina.ClusterData(flat_distances, n_points, cutoff, isDistData=True))

    # biggest clusters first
    groups.sort(key=len, reverse=True)
    return groups
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class KMeans:
    def __init__(self, n_clusters: int = 8, init: str = 'k-means++', n_init: int = 1, max_iter: int = 300,
                 tol: float = 1e-4, random_state: int = None):
        """
        Initialize the KMeans object. Params should be similar to what can be found on SKlearn:
        https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
        :param n_clusters: int
            Number of clusters and centroids to generate.
        :param init: str
            Method of initialization: 'k-means++' or 'random' (data points drawn uniformly with replacement).
        :param n_init: int
            Number of times the k-means algorithm is run with different centroid seeds. The run with the lowest
            inertia is kept.
        :param max_iter: int
            Maximum number of iterations of the k-means algorithm for a single run.
        :param tol: float
            Absolute tolerance on the largest Euclidean shift of any cluster center between two consecutive
            iterations to declare convergence.
        :param random_state: int
            Determines random number generation for centroid initialization.
        """
        self.n_clusters = n_clusters
        self.init = init
        self.n_init = n_init
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.cluster_centers_ = None
        self.labels_ = None
        self.inertia_ = float('inf')

    def fit(self, X: mx.array):
        """
        Compute k-means clustering.
        :param X: mx.array
            Training instances to cluster, shape (N, D). Anything that is not an mx.array is converted with
            mx.array(); the data is cast to mx.float32.
        :return:
            The fitted estimator (self), with cluster_centers_, labels_ and inertia_ set.
        :raises ValueError:
            If init is not one of the supported initialization methods.
        """
        if not isinstance(X, mx.array):
            X = mx.array(X)
        X = X.astype(mx.float32)

        if self.init not in ('k-means++', 'random'):
            raise ValueError(f"init '{self.init}' is not supported. Use 'k-means++' or 'random'.")

        if self.random_state is not None:
            mx.random.seed(self.random_state)

        N, _ = X.shape
        X_sq = mx.sum(X * X, axis=-1, keepdims=True)

        best_centers, best_labels, best_inertia = None, None, float('inf')

        # at least one run, even if n_init was set to 0 or less
        for _ in range(max(1, self.n_init)):
            if self.init == 'k-means++':
                centers = self._kmeans_plusplus(X, N)
            else:
                centers = X[mx.random.randint(0, N, [self.n_clusters])]

            centers = self._lloyd(X, X_sq, centers, N)

            # final assignment against the final centers so labels_, inertia_ and predict(X) all agree
            C_sq = mx.sum(centers * centers, axis=-1)
            distances = X_sq + C_sq - 2.0 * mx.matmul(X, centers.T)
            labels = mx.argmin(distances, axis=-1)
            # clamp: the expanded squared distance can dip slightly below zero from rounding
            inertia = mx.sum(mx.maximum(mx.min(distances, axis=-1), 0.0)).item()

            if best_centers is None or inertia < best_inertia:
                best_centers, best_labels, best_inertia = centers, labels, inertia

        mx.eval(best_centers, best_labels)
        self.cluster_centers_ = best_centers
        self.labels_ = best_labels
        self.inertia_ = best_inertia
        return self

    def _lloyd(self, X, X_sq, centers, N):
        """
        Run the assign/update iterations from the given starting centers.
        :param X:
            Data as float32 mx.array, shape (N, D).
        :param X_sq:
            Squared row norms of X, shape (N, 1).
        :param centers:
            Starting centers, shape (K, D).
        :param N:
            Number of rows in X.
        :return:
            The centers after convergence or max_iter iterations.
        """
        # pre-create an array of cluster indices [0, 1, 2, ... K-1]
        cluster_indices = mx.arange(self.n_clusters)[None, :]

        for _ in range(self.max_iter):
            # distances & assignments
            C_sq = mx.sum(centers * centers, axis=-1)
            distances = X_sq + C_sq - 2.0 * mx.matmul(X, centers.T)
            labels = mx.argmin(distances, axis=-1)

            # one-hot encoded matrix of labels (N, K)
            one_hot = (labels[:, None] == cluster_indices).astype(mx.float32)

            # sum datapoints in each cluster (K, N) matmul (N, D) -> (K, D)
            cluster_sums = mx.matmul(one_hot.T, X)

            # count points in each cluster (K, 1)
            cluster_counts = mx.sum(one_hot, axis=0, keepdims=True).T

            # replace 0 counts with 1 to avoid NaN errors
            new_centers = cluster_sums / mx.maximum(cluster_counts, 1.0)

            # empty clusters - replace with random data points
            empty_mask = (cluster_counts == 0)
            random_replacements = X[mx.random.randint(0, N, [self.n_clusters])]
            new_centers = mx.where(empty_mask, random_replacements, new_centers)

            # convergence check: largest movement of any single center
            shift = mx.max(mx.sqrt(mx.sum((centers - new_centers) ** 2, axis=-1)))
            centers = new_centers

            # force evaluation each iteration so the lazy graph does not keep growing
            mx.eval(centers)

            if shift.item() < self.tol:
                break

        return centers

    def predict(self, X: mx.array):
        """
        Predict the closest cluster each sample in X belongs to.
        :param X: mx.array
            New data to predict, shape (N, D).
        :return:
            mx.array with the index of the closest cluster center for each row of X.
        :raises ValueError:
            If the model has not been fitted yet.
        """
        if self.cluster_centers_ is None:
            raise ValueError("KMeans Model is not fitted yet!")

        # same input normalisation as fit()
        if not isinstance(X, mx.array):
            X = mx.array(X)
        X = X.astype(mx.float32)

        X_sq = mx.sum(X * X, axis=-1, keepdims=True)
        C_sq = mx.sum(self.cluster_centers_ * self.cluster_centers_, axis=-1)
        distances = X_sq + C_sq - 2.0 * mx.matmul(X, self.cluster_centers_.T)
        return mx.argmin(distances, axis=-1)

    def pairwise_distances_argmin_min(self, array: mx.array):
        """
        Mimics sklearn.metrics.pairwise_distances_argmin_min with the cluster centers as the first argument:
        for every cluster center, find the closest row of array.
        :param array: mx.array
            The original data as type mx.array() (shape: N, D)
        :return:
            Tuple (closest_idx, min_distances), each of shape (K,): the row index in array closest to each center
            and the corresponding Euclidean distance.
        :raises ValueError:
            If the model has not been fitted yet.
        """
        if self.cluster_centers_ is None:
            raise ValueError("KMeans Model is not fitted yet!")

        centers = self.cluster_centers_

        # calculate squared distances
        centers_sq = mx.sum(centers * centers, axis=-1, keepdims=True)  # Shape (K, 1)
        array_sq = mx.sum(array * array, axis=-1)  # Shape (N,)

        # distance matrix shape: (K, N)
        distances_sq = centers_sq + array_sq - 2.0 * mx.matmul(centers, array.T)

        # get the index of the minimum distance along the N dimension
        closest_idx = mx.argmin(distances_sq, axis=-1)

        # get the actual minimum distances
        min_sq_distances = mx.min(distances_sq, axis=-1)

        # square root for true Euclidean distance (clamped: rounding can give tiny negatives)
        min_distances = mx.sqrt(mx.maximum(min_sq_distances, 0.0))

        return closest_idx, min_distances

    def _kmeans_plusplus(self, X, N):
        """
        Support function for k-means++: pick n_clusters rows of X as starting centers, each new one sampled with
        probability proportional to its squared distance to the nearest center chosen so far.
        :param X:
            Data as mx.array, shape (N, D).
        :param N:
            Number of rows in X.
        :return:
            mx.array of shape (n_clusters, D) with the chosen starting centers.
        """
        centers = []
        first_idx = mx.random.randint(0, N, [1]).item()
        centers.append(X[first_idx])
        X_sq = mx.sum(X * X, axis=-1, keepdims=True)

        for _ in range(1, self.n_clusters):
            current_centers = mx.stack(centers)
            C_sq = mx.sum(current_centers * current_centers, axis=-1)
            distances = X_sq + C_sq - 2.0 * mx.matmul(X, current_centers.T)
            # clamp before the log: rounding can push a squared distance slightly below zero -> NaN logits
            min_distances = mx.maximum(mx.min(distances, axis=-1), 0.0)

            # categorical sampling on log(d^2) draws indices with probability proportional to d^2
            logits = mx.log(min_distances + 1e-8)
            next_idx = mx.random.categorical(logits, shape=[1]).item()
            centers.append(X[next_idx])

        return mx.stack(centers)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# todo test DBSCAN
|
|
239
|
+
class DBSCAN:
    def __init__(self, eps=0.5, min_samples=5, metric="euclidean", chunk_size=5_000, p=None):
        """
        Density-based clustering with the pairwise distances computed on MLX.
        :param eps: float
            Maximum distance between two samples for them to count as neighbors.
        :param min_samples: int
            Number of neighbors (the point itself included) required for a point to be a core point.
        :param metric: str
            One of "euclidean", "manhattan", "cosine", "tanimoto" or "minkowski".
        :param chunk_size: int
            Chunk size forwarded to get_tanimoto; only used by the "tanimoto" metric.
        :param p: float
            Power of the Minkowski metric; only used when metric="minkowski". None is treated as 2 (Euclidean).
        """
        self.eps = eps
        self.min_samples = min_samples
        self.metric = metric
        self.chunk_size = chunk_size
        self.p = p
        self.labels_ = None

    def fit(self, X: mx.array):
        """
        Cluster the rows of X.
        :param X: mx.array
            Data to cluster, shape (N, D). Anything that is not an mx.array is converted with mx.array(); the data
            is cast to mx.float32.
        :return:
            The fitted estimator (self). labels_ is a NumPy integer array with one cluster id per row, -1 for noise.
        """
        if not isinstance(X, mx.array):
            X = mx.array(X)
        X = X.astype(mx.float32)

        n_samples = X.shape[0]
        if n_samples == 0:
            self.labels_ = np.full(0, -1)
            return self

        # distances are computed on MLX; the masking is done in NumPy because the tanimoto path already
        # hands back a NumPy array and the cluster expansion below runs on the CPU anyway
        dist_np = np.array(self._compute_distances(X))
        adj_np = dist_np <= self.eps

        # core point detection
        is_core_np = adj_np.sum(axis=1) >= self.min_samples

        labels = np.full(n_samples, -1)

        cluster_id = 0
        for i in range(n_samples):
            # only unassigned core points can start a new cluster
            if labels[i] != -1 or not is_core_np[i]:
                continue

            labels[i] = cluster_id
            stack = [i]
            while stack:
                curr = stack.pop()
                # neighbors of curr that have not been claimed by any cluster yet
                neighbors = np.flatnonzero(adj_np[curr])
                fresh = neighbors[labels[neighbors] == -1]
                labels[fresh] = cluster_id
                # only core points keep expanding the cluster; border points are labelled but not followed
                stack.extend(fresh[is_core_np[fresh]].tolist())
            cluster_id += 1

        self.labels_ = labels
        return self

    def _compute_distances(self, X: mx.array):
        """
        Vectorized pairwise distance calculation on mlx.
        :param X: mx.array
            Data, shape (N, D).
        :return:
            The (N, N) distance matrix. The "tanimoto" metric returns whatever get_tanimoto returns; the other
            metrics return an mx.array.
        :raises ValueError:
            If the metric is not supported, or p is not positive for the Minkowski metric.
        """
        metric = self.metric

        if metric == "minkowski":
            p = 2 if self.p is None else self.p
            if p <= 0:
                raise ValueError(f"p must be positive for the Minkowski metric, got {p}.")
            if p == 1:
                metric = "manhattan"
            elif p == 2:
                metric = "euclidean"
            else:
                # general Minkowski; materializes an (N, N, D) intermediate
                diff = mx.abs(X[:, None, :] - X[None, :, :])
                return mx.sum(diff ** p, axis=-1) ** (1.0 / p)

        if metric == "tanimoto":
            return get_tanimoto(fps=X, chunk_size=self.chunk_size, matrix=True)
        elif metric == "euclidean":
            # optimized L2: sqrt(sum(x^2) + sum(y^2) - 2 * x.T * y)
            sq_norms = mx.sum(X ** 2, axis=1, keepdims=True)
            dist_sq = sq_norms + sq_norms.T - 2 * mx.matmul(X, X.T)
            # clamp: rounding can make the expanded form slightly negative
            return mx.sqrt(mx.maximum(dist_sq, 0.0))
        elif metric == "manhattan":
            # L1 Distance; materializes an (N, N, D) intermediate
            return mx.sum(mx.abs(X[:, None, :] - X[None, :, :]), axis=-1)
        elif metric == "cosine":
            # cosine Distance = 1 - (A·B / (||A||*||B||)); epsilon guards against zero-norm rows
            norm = mx.sqrt(mx.sum(X ** 2, axis=1, keepdims=True))
            similarity = mx.matmul(X, X.T) / (norm * norm.T + 1e-7)
            return 1.0 - similarity
        else:
            raise ValueError(f"Metric '{self.metric}' is not supported in this MLX implementation.")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# class MLXSpectralClustering:
|
|
303
|
+
# def __init__(self, n_clusters=8, gamma=1.0, affinity='rbf', assign_labels='kmeans'):
|
|
304
|
+
# self.n_clusters = n_clusters
|
|
305
|
+
# self.gamma = gamma
|
|
306
|
+
# self.affinity = affinity
|
|
307
|
+
# self.assign_labels = assign_labels
|
|
308
|
+
# self.labels_ = None
|
|
309
|
+
#
|
|
310
|
+
# def fit_predict(self, X):
|
|
311
|
+
# N = X.shape[0]
|
|
312
|
+
#
|
|
313
|
+
# # 1. Compute Affinity Matrix (RBF Kernel)
|
|
314
|
+
# # Using the same logic as sklearn: exp(-gamma * ||x-y||^2)
|
|
315
|
+
# sq_norms = mx.sum(X ** 2, axis=1)
|
|
316
|
+
# dist_sq = sq_norms[:, None] + sq_norms[None, :] - 2 * mx.matmul(X, X.T)
|
|
317
|
+
# A = mx.exp(-self.gamma * dist_sq)
|
|
318
|
+
#
|
|
319
|
+
# # 2. Compute Degree Matrix and Laplacian
|
|
320
|
+
# # L = D - A (Unnormalized) or L = I - D^-1/2 A D^-1/2 (Normalized)
|
|
321
|
+
# D = mx.sum(A, axis=1)
|
|
322
|
+
# D_inv_sqrt = 1.0 / mx.sqrt(D)
|
|
323
|
+
# L_norm = mx.eye(N) - (D_inv_sqrt[:, None] * A * D_inv_sqrt[None, :])
|
|
324
|
+
#
|
|
325
|
+
# # 3. Eigen Decomposition
|
|
326
|
+
# # We need the eigenvectors corresponding to the smallest eigenvalues
|
|
327
|
+
# evals, evecs = mx.linalg.eigh(L_norm)
|
|
328
|
+
#
|
|
329
|
+
# # 4. Extract Top K Eigenvectors (Spectral Embedding)
|
|
330
|
+
# U = evecs[:, :self.n_clusters]
|
|
331
|
+
#
|
|
332
|
+
# # Normalize rows to unit length (important for stability)
|
|
333
|
+
# U = U / mx.linalg.norm(U, axis=1, keepdims=True)
|
|
334
|
+
#
|
|
335
|
+
# # 5. Final Step: Run your existing KMeans on the embedding
|
|
336
|
+
# from your_kmeans_file import KMeans # Use your existing class here
|
|
337
|
+
# km = KMeans(n_clusters=self.n_clusters)
|
|
338
|
+
# self.labels_ = km.fit(U).labels_
|
|
339
|
+
#
|
|
340
|
+
# return self.labels_
|
|
341
|
+
#
|
|
342
|
+
# class MLXGaussianMixture:
|
|
343
|
+
# def __init__(self, n_components=1, tol=1e-3, max_iter=100, reg_covar=1e-6):
|
|
344
|
+
# self.n_components = n_components
|
|
345
|
+
# self.tol = tol
|
|
346
|
+
# self.max_iter = max_iter
|
|
347
|
+
# self.reg_covar = reg_covar # Matches sklearn's stability constant
|
|
348
|
+
#
|
|
349
|
+
# self.weights_ = None
|
|
350
|
+
# self.means_ = None
|
|
351
|
+
# self.covariances_ = None
|
|
352
|
+
#
|
|
353
|
+
# def fit(self, X):
|
|
354
|
+
# N, D = X.shape
|
|
355
|
+
# # Initialize weights uniformly and means randomly from data
|
|
356
|
+
# self.weights_ = mx.full((self.n_components,), 1.0 / self.n_components)
|
|
357
|
+
# self.means_ = X[mx.random.randint(0, N, (self.n_components,))]
|
|
358
|
+
# self.covariances_ = mx.stack([mx.eye(D) for _ in range(self.n_components)])
|
|
359
|
+
#
|
|
360
|
+
# prev_log_likelihood = -float('inf')
|
|
361
|
+
#
|
|
362
|
+
# for i in range(self.max_iter):
|
|
363
|
+
# # --- E-Step: Compute Responsibilities ---
|
|
364
|
+
# resp = self._estimate_responsibilities(X)
|
|
365
|
+
#
|
|
366
|
+
# # --- M-Step: Update Parameters ---
|
|
367
|
+
# nk = mx.sum(resp, axis=0) # Total weight in each cluster
|
|
368
|
+
# self.weights_ = nk / N
|
|
369
|
+
# self.means_ = mx.matmul(resp.T, X) / nk[:, None]
|
|
370
|
+
#
|
|
371
|
+
# for k in range(self.n_components):
|
|
372
|
+
# diff = X - self.means_[k]
|
|
373
|
+
# weighted_diff = diff * mx.sqrt(resp[:, k:k + 1])
|
|
374
|
+
# # Add regularization to diagonal for numerical stability
|
|
375
|
+
# self.covariances_[k] = mx.matmul(weighted_diff.T, weighted_diff) / nk[k] + \
|
|
376
|
+
# mx.eye(D) * self.reg_covar
|
|
377
|
+
#
|
|
378
|
+
# # Convergence Check
|
|
379
|
+
# current_log_likelihood = self._compute_log_likelihood(X)
|
|
380
|
+
# if abs(current_log_likelihood - prev_log_likelihood) < self.tol:
|
|
381
|
+
# break
|
|
382
|
+
# prev_log_likelihood = current_log_likelihood
|
|
383
|
+
# mx.eval(self.means_, self.covariances_)
|
|
384
|
+
#
|
|
385
|
+
# return self
|
|
386
|
+
#
|
|
387
|
+
# def _estimate_responsibilities(self, X):
|
|
388
|
+
# # Calculates P(cluster | point) using log-sum-exp for stability
|
|
389
|
+
# weighted_log_probs = self._compute_log_prob(X) + mx.log(self.weights_)
|
|
390
|
+
# log_prob_norm = mx.logsumexp(weighted_log_probs, axis=1, keepdims=True)
|
|
391
|
+
# return mx.exp(weighted_log_probs - log_prob_norm)
|
|
392
|
+
#
|
|
393
|
+
# def _compute_log_prob(self, X):
|
|
394
|
+
# # Vectorized multivariate normal log-pdf
|
|
395
|
+
# N, D = X.shape
|
|
396
|
+
# probs = []
|
|
397
|
+
# for k in range(self.n_components):
|
|
398
|
+
# diff = X - self.means_[k]
|
|
399
|
+
# # MLX handles linalg.inv and det very efficiently on GPU
|
|
400
|
+
# prec = mx.linalg.inv(self.covariances_[k])
|
|
401
|
+
# log_det = mx.log(mx.linalg.det(self.covariances_[k]))
|
|
402
|
+
#
|
|
403
|
+
# # Log Mahalanobis distance
|
|
404
|
+
# dist = mx.sum(mx.matmul(diff, prec) * diff, axis=1)
|
|
405
|
+
# log_prob = -0.5 * (D * mx.log(2 * 3.14159) + log_det + dist)
|
|
406
|
+
# probs.append(log_prob)
|
|
407
|
+
# return mx.stack(probs, axis=1)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
if __name__ == "__main__":
    # Executed as a script: run whatever doctests this module's docstrings contain.
    from doctest import testmod

    testmod()
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "mlx-mol-cluster"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = ""
|
|
5
|
+
authors = [{ name = "Tony Eight Lin" }]
|
|
6
|
+
license = "MIT"
|
|
7
|
+
readme = "README.md"
|
|
8
|
+
requires-python = ">=3.11"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"mlx (>=0.31.0,<0.32.0)",
|
|
11
|
+
"numpy (>=2.4.2,<3.0.0)",
|
|
12
|
+
"pandas (>=3.0.1,<4.0.0)",
|
|
13
|
+
"seaborn (>=0.13.2,<0.14.0)",
|
|
14
|
+
"rdkit (>=2025.9.5,<2026.0.0)",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.poetry]
|
|
18
|
+
packages = [{ include = "mlx_cluster" }]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
[build-system]
|
|
22
|
+
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
|
23
|
+
build-backend = "poetry.core.masonry.api"
|
|
24
|
+
|
|
25
|
+
[dependency-groups]
|
|
26
|
+
dev = [
|
|
27
|
+
"jupyter (>=1.1.1,<2.0.0)",
|
|
28
|
+
"ipywidgets (>=8.1.8,<9.0.0)",
|
|
29
|
+
"bblean (>=0.10.1,<0.11.0)",
|
|
30
|
+
"mols2grid (>=2.2.0,<3.0.0)",
|
|
31
|
+
"scikit-learn (>=1.8.0,<2.0.0)"
|
|
32
|
+
]
|