mlx-mol-cluster 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Tony E. Lin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,65 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-mol-cluster
3
+ Version: 0.1.0
4
+ Summary:
5
+ License-Expression: MIT
6
+ License-File: LICENSE
7
+ Author: Tony Eight Lin
8
+ Requires-Python: >=3.11
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Classifier: Programming Language :: Python :: 3.14
14
+ Requires-Dist: mlx (>=0.31.0,<0.32.0)
15
+ Requires-Dist: numpy (>=2.4.2,<3.0.0)
16
+ Requires-Dist: pandas (>=3.0.1,<4.0.0)
17
+ Requires-Dist: rdkit (>=2025.9.5,<2026.0.0)
18
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
19
+ Description-Content-Type: text/markdown
20
+
21
+ [![Python Versions](https://img.shields.io/badge/python-3.11+-blue.svg?logo=python&logoColor=white)](https://github.com/tlint101/MLXMolCluster)
22
+ [![MLX](https://img.shields.io/badge/MLX-0.31.0+-black?logo=apple&labelColor=gray)](https://github.com/ml-explore/mlx)
23
+
24
+ # MLXMolCluster
25
+
26
+ Leverage Apple Silicon to cluster molecules using MLX.
27
+
28
+ At the time of writing, the project contains two clustering methods:
29
+ - Butina
30
+ - KMeans
31
+
32
+ Additional clustering methods will be added over time.
33
+
34
+ Examples have been written and can be found [here](tutorial).
35
+
36
+ ## Installation
37
+ Clone and install locally:
38
+ ```python
39
+ pip install git+https://github.com/tlint101/MLXMolCluster.git
40
+ ```
41
+
42
+ ## Example
43
+ The following is an example of clustering molecules using Butina on MLX.
44
+ ```python
45
+ # generate molecular fingerprints
46
+ fp_gen = rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=1024)
47
+ rdkit_fps = [fp_gen.GetFingerprint(mol) for mol in mol_list]
48
+
49
+ # convert to mlx arrays
50
+ mlx_fp = fp_to_mlx(rdkit_fps)
51
+
52
+ # Butina cluster
53
+ butina_mlx = butina(mlx_fp)
54
+ ```
55
+
56
+ A speed comparison can be seen at the [tutorial section](tutorial). The runs were performed on an M2 Pro chip
57
+ (10 CPU cores, 16 GPU cores).
58
+ ![Comparisons of Clustering Methods](img/cluster_compare.png "Clustering 10,000 Molecules")
59
+
60
+ **NOTE:** The figure can be misleading. The figure shows the clustering speed of already generated molecular
61
+ fingerprints. The main bottleneck of clustering remains on the generation of molecular fingerprints. This is done on the
62
+ CPU before being converted to the GPU. Depending on the number of molecules, this can be time intensive.
63
+
64
+ ### Additional
65
+ Collaborations are welcome!
@@ -0,0 +1,45 @@
1
+ [![Python Versions](https://img.shields.io/badge/python-3.11+-blue.svg?logo=python&logoColor=white)](https://github.com/tlint101/MLXMolCluster)
2
+ [![MLX](https://img.shields.io/badge/MLX-0.31.0+-black?logo=apple&labelColor=gray)](https://github.com/ml-explore/mlx)
3
+
4
+ # MLXMolCluster
5
+
6
+ Leverage Apple Silicon to cluster molecules using MLX.
7
+
8
+ At the time of writing, the project contains two clustering methods:
9
+ - Butina
10
+ - KMeans
11
+
12
+ Additional clustering methods will be added over time.
13
+
14
+ Examples have been written and can be found [here](tutorial).
15
+
16
+ ## Installation
17
+ Clone and install locally:
18
+ ```python
19
+ pip install git+https://github.com/tlint101/MLXMolCluster.git
20
+ ```
21
+
22
+ ## Example
23
+ The following is an example of clustering molecules using Butina on MLX.
24
+ ```python
25
+ # generate molecular fingerprints
26
+ fp_gen = rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=1024)
27
+ rdkit_fps = [fp_gen.GetFingerprint(mol) for mol in mol_list]
28
+
29
+ # convert to mlx arrays
30
+ mlx_fp = fp_to_mlx(rdkit_fps)
31
+
32
+ # Butina cluster
33
+ butina_mlx = butina(mlx_fp)
34
+ ```
35
+
36
+ A speed comparison can be seen at the [tutorial section](tutorial). The runs were performed on an M2 Pro chip
37
+ (10 CPU cores, 16 GPU cores).
38
+ ![Comparisons of Clustering Methods](img/cluster_compare.png "Clustering 10,000 Molecules")
39
+
40
+ **NOTE:** The figure can be misleading. The figure shows the clustering speed of already generated molecular
41
+ fingerprints. The main bottleneck of clustering remains on the generation of molecular fingerprints. This is done on the
42
+ CPU before being converted to the GPU. Depending on the number of molecules, this can be time intensive.
43
+
44
+ ### Additional
45
+ Collaborations are welcome!
@@ -0,0 +1 @@
1
+ from .cluster import *
@@ -0,0 +1,413 @@
1
+ import numpy as np
2
+ import mlx.core as mx
3
+ from rdkit.ML.Cluster import Butina
4
+ from typing import Optional
5
+
6
+
7
def fp_to_mlx(fp: list) -> mx.array:
    """
    Convert a list of RDKit molecular fingerprints into an MLX array.
    :param fp: list
        A list of molecular fingerprints calculated using RDKit.
    :return:
        An mx.array of dtype float32 with one row per fingerprint.
    """
    # NumPy understands RDKit's bit vectors; MLX then takes the buffer as float32
    stacked = np.array(fp)
    return mx.array(stacked).astype(mx.float32)
18
+
19
+
20
def get_tanimoto(fps: Optional[mx.array] = None, chunk_size: int = 5000, matrix: bool = False):
    """
    Calculate pairwise Tanimoto distances (1 - similarity) between fingerprints.
    :param fps: Optional[mx.array]
        An mx.array of molecular fingerprints (shape: n_mols, n_bits). Required.
    :param chunk_size: int
        Number of rows processed per chunk to stay under GPU buffer limits.
    :param matrix: bool
        If True, return the full (n, n) distance matrix; otherwise return the
        flattened strict lower triangle expected by RDKit's Butina clustering.
    :return:
        A np.ndarray of Tanimoto distances.
    :raises ValueError: if fps is None.
    """
    if fps is None:
        # the signature allows None for API symmetry, but there is nothing to compute
        raise ValueError("fps is required; pass an mx.array of fingerprints.")

    n = fps.shape[0]
    bits_set = mx.sum(fps, axis=1, keepdims=True)  # on-bits per fingerprint, shape (n, 1)

    results = []

    # process by chunks
    for i in range(0, n, chunk_size):
        end_i = min(i + chunk_size, n)
        chunk_fps = fps[i:end_i]  # shape: (chunk, bits)
        chunk_bits = bits_set[i:end_i]  # shape: (chunk, 1)

        # compute intersection counts against every fingerprint
        intersections = mx.matmul(chunk_fps, fps.T)

        # Tanimoto: |A ∩ B| / |A ∪ B|; epsilon guards against 0/0 for empty fingerprints
        union = chunk_bits + bits_set.T - intersections
        tanimoto_sim = intersections / (union + 1e-7)
        dist_chunk = 1.0 - tanimoto_sim

        if not matrix:
            # only keep the strict lower triangle, row by row
            for row_idx in range(i, end_i):
                # row within the distance chunk is (row_idx - i)
                actual_row = dist_chunk[row_idx - i, :row_idx]
                results.append(np.array(actual_row))  # move to CPU/NumPy
        else:
            # convert eagerly so the final concatenate is a pure NumPy operation
            results.append(np.array(dist_chunk))

    # flatten / stack the per-chunk pieces into a single numpy array
    return np.concatenate(results)
60
+
61
+
62
def butina(fingerprints: mx.array = None, cutoff: float = 0.2, chunk_size: int = 5000) -> list:
    """
    Cluster molecular fingerprints with the Butina algorithm.
    :param fingerprints: mx.array
        An mx.array of molecular fingerprints (one row per molecule). Required.
    :param cutoff: float
        Tanimoto distance threshold; molecules closer than this join the same cluster.
    :param chunk_size: int
        Number of rows processed per chunk to stay under GPU buffer limits.
    :return:
        A list of clusters (tuples of molecule indices), sorted largest first.
    :raises ValueError: if fingerprints is None.
    """
    if fingerprints is None:
        # fail with a clear message instead of an opaque AttributeError downstream
        raise ValueError("fingerprints is required; pass an mx.array of molecular fingerprints.")
    # flattened lower-triangle Tanimoto distances (the layout RDKit expects)
    distance_matrix = get_tanimoto(fingerprints, chunk_size=chunk_size, matrix=False)
    # cluster on the precomputed distances
    clusters = Butina.ClusterData(distance_matrix, len(fingerprints), cutoff, isDistData=True)
    # largest clusters first
    return sorted(clusters, key=len, reverse=True)
79
+
80
+
81
class KMeans:
    def __init__(self, n_clusters: int = 8, init: str = 'k-means++', n_init: int = 1, max_iter: int = 300,
                 tol: float = 1e-4, random_state: Optional[int] = None):
        """
        Initialize the KMeans object. Params should be similar to what can be found on SKlearn:
        https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
        :param n_clusters: int
            Number of clusters and centroids to generate.
        :param init: str
            Method of initialization. 'k-means++' (default) or 'random'.
        :param n_init: int
            Number of times the k-means algorithm is run with different centroid seeds;
            the run with the lowest inertia is kept.
        :param max_iter: int
            Maximum number of iterations of the k-means algorithm for a single run.
        :param tol: float
            Relative tolerance with regards to Frobenius norm of the difference in the cluster centers of two
            consecutive iterations to declare convergence.
        :param random_state: Optional[int]
            Determines random number generation for centroid initialization.
        """
        self.n_clusters = n_clusters
        self.init = init
        self.n_init = n_init
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.cluster_centers_ = None  # (K, D) mx.array after fit
        self.labels_ = None  # (N,) mx.array after fit
        self.inertia_ = float('inf')  # sum of squared distances to closest center

    def fit(self, X: mx.array):
        """
        Compute k-means clustering.
        :param X: mx.array
            Training instances to cluster. Converted to mx.array(float32) if needed.
        :return:
            self, with cluster_centers_, labels_ and inertia_ populated.
        """
        if not isinstance(X, mx.array):
            X = mx.array(X)
        X = X.astype(mx.float32)

        if self.random_state is not None:
            mx.random.seed(self.random_state)

        N, _ = X.shape

        # BUG FIX: 'init' and 'n_init' were previously accepted but ignored —
        # fit() always used random initialization and _kmeans_plusplus was dead
        # code, and inertia_ was never computed. Run n_init seeded attempts and
        # keep the one with the lowest inertia, as documented.
        best_inertia = float('inf')
        best_centers = None
        best_labels = None
        for _ in range(max(1, self.n_init)):
            centers, labels, inertia = self._fit_single(X, N)
            if inertia < best_inertia:
                best_inertia, best_centers, best_labels = inertia, centers, labels

        self.cluster_centers_ = best_centers
        self.labels_ = best_labels
        self.inertia_ = best_inertia
        return self

    def _fit_single(self, X: mx.array, N: int):
        """One complete k-means run: initialize, Lloyd-iterate, return (centers, labels, inertia)."""
        X_sq = mx.sum(X * X, axis=-1, keepdims=True)

        # initialization per the 'init' parameter
        if self.init == 'k-means++':
            centers = self._kmeans_plusplus(X, N)
        else:
            centers = X[mx.random.randint(0, N, [self.n_clusters])]

        # pre-create an array of cluster indices [0, 1, 2, ... K-1]
        cluster_indices = mx.arange(self.n_clusters)[None, :]
        labels = None

        for _ in range(self.max_iter):
            # distances & assignments (expanded ||x - c||^2)
            C_sq = mx.sum(centers * centers, axis=-1)
            distances = X_sq + C_sq - 2.0 * mx.matmul(X, centers.T)
            labels = mx.argmin(distances, axis=-1)

            # one-hot encoded matrix of labels (N, K)
            one_hot = (labels[:, None] == cluster_indices).astype(mx.float32)

            # sum datapoints in each cluster (K, N) matmul (N, D) -> (K, D)
            cluster_sums = mx.matmul(one_hot.T, X)

            # count points in each cluster (K, 1)
            cluster_counts = mx.sum(one_hot, axis=0, keepdims=True).T

            # replace 0 counts with 1 to avoid NaN errors
            safe_counts = mx.maximum(cluster_counts, 1.0)
            new_centers = cluster_sums / safe_counts

            # empty clusters - replace with random data points
            empty_mask = (cluster_counts == 0)
            random_replacements = X[mx.random.randint(0, N, [self.n_clusters])]
            new_centers = mx.where(empty_mask, random_replacements, new_centers)

            # convergence check: largest per-center movement
            shift = mx.max(mx.sqrt(mx.sum((centers - new_centers) ** 2, axis=-1)))
            centers = new_centers

            mx.eval(centers, labels)

            if shift < self.tol:
                break

        # inertia: sum of squared distances to the nearest final center
        C_sq = mx.sum(centers * centers, axis=-1)
        distances = X_sq + C_sq - 2.0 * mx.matmul(X, centers.T)
        inertia = mx.sum(mx.maximum(mx.min(distances, axis=-1), 0.0)).item()
        return centers, labels, inertia

    def predict(self, X: mx.array):
        """
        Predict the closest cluster each sample in X belongs to.
        :param X: mx.array
            New data to predict. Converted to mx.array(float32) if needed.
        :return:
            An (N,) mx.array of cluster indices.
        """
        if not isinstance(X, mx.array):
            X = mx.array(X)
        X = X.astype(mx.float32)
        X_sq = mx.sum(X * X, axis=-1, keepdims=True)
        C_sq = mx.sum(self.cluster_centers_ * self.cluster_centers_, axis=-1)
        distances = X_sq + C_sq - 2.0 * mx.matmul(X, self.cluster_centers_.T)
        return mx.argmin(distances, axis=-1)

    def pairwise_distances_argmin_min(self, array: mx.array):
        """
        Mimics sklearn.metrics.pairwise_distances_argmin_min: for every fitted
        center, find the index of (and Euclidean distance to) the closest point.
        :param array: mx.array
            The original data as type mx.array() (shape: N, D)
        :return:
            (closest_idx, min_distances), one entry per cluster center.
        :raises ValueError: if the model has not been fitted.
        """
        if self.cluster_centers_ is None:
            raise ValueError("KMeans Model is not fitted yet!")

        centers = self.cluster_centers_

        # calculate squared distances
        centers_sq = mx.sum(centers * centers, axis=-1, keepdims=True)  # Shape (K, 1)
        array_sq = mx.sum(array * array, axis=-1)  # Shape (N,)

        # distance matrix shape: (K, N)
        distances_sq = centers_sq + array_sq - 2.0 * mx.matmul(centers, array.T)

        # get the index of the minimum distance along the N dimension
        closest_idx = mx.argmin(distances_sq, axis=-1)

        # get the actual minimum distances
        min_sq_distances = mx.min(distances_sq, axis=-1)

        # clamp at 0 (numerical error can go slightly negative) before the sqrt
        min_distances = mx.sqrt(mx.maximum(min_sq_distances, 0.0))

        return closest_idx, min_distances

    def _kmeans_plusplus(self, X: mx.array, N: int):
        """
        k-means++ seeding: pick the first center uniformly, then sample each
        subsequent center with probability proportional to its squared distance
        from the nearest already-chosen center.
        :param X: mx.array
            Data of shape (N, D).
        :param N: int
            Number of data points.
        :return:
            An (n_clusters, D) mx.array of initial centers.
        """
        centers = []
        first_idx = mx.random.randint(0, N, [1]).item()
        centers.append(X[first_idx])
        X_sq = mx.sum(X * X, axis=-1, keepdims=True)

        for _ in range(1, self.n_clusters):
            current_centers = mx.stack(centers)
            C_sq = mx.sum(current_centers * current_centers, axis=-1)
            distances = X_sq + C_sq - 2.0 * mx.matmul(X, current_centers.T)
            min_distances = mx.min(distances, axis=-1)

            # categorical over log(d^2) samples proportionally to squared distance
            logits = mx.log(min_distances + 1e-8)
            next_idx = mx.random.categorical(logits, shape=[1]).item()
            centers.append(X[next_idx])

        return mx.stack(centers)
236
+
237
+
238
+ # todo test DBSCAN
239
class DBSCAN:
    def __init__(self, eps=0.5, min_samples=5, metric="euclidean", chunk_size=5_000, p=2):
        """
        Density-based clustering (DBSCAN) with MLX-accelerated distance computation.
        Params mirror sklearn.cluster.DBSCAN where applicable.
        :param eps: float
            Maximum distance for two samples to be considered neighbors.
        :param min_samples: int
            Number of neighbors (including the point itself) required for a core point.
        :param metric: str
            One of 'euclidean', 'manhattan', 'cosine', 'tanimoto', or 'minkowski' (p=1 only).
        :param chunk_size: int
            Chunk size forwarded to the Tanimoto kernel to stay under GPU buffer limits.
        :param p: int
            Minkowski power parameter; only p=1 (equivalent to manhattan) is supported.
        """
        self.eps = eps
        self.min_samples = min_samples
        self.metric = metric
        self.chunk_size = chunk_size
        # BUG FIX: self.p was read in _compute_distances but never assigned, so
        # metric="minkowski" raised AttributeError instead of dispatching on p.
        self.p = p
        self.labels_ = None  # np.ndarray of cluster ids after fit; -1 = noise

    def fit(self, X: mx.array):
        """
        Cluster X; populates self.labels_ (noise points get label -1).
        :param X: mx.array
            Data of shape (N, D), or fingerprints when metric='tanimoto'.
        :return:
            self
        """
        n_samples = X.shape[0]
        # distance and neighborhood masking
        dist_matrix = self._compute_distances(X)
        adj_matrix = dist_matrix <= self.eps  # boolean adjacency

        # core point detection
        neighbor_counts = mx.sum(adj_matrix, axis=1)
        is_core = neighbor_counts >= self.min_samples

        # convert to np.array for the sequential label expansion
        adj_np = np.array(adj_matrix)
        is_core_np = np.array(is_core)
        self.labels_ = np.full(n_samples, -1)

        cluster_id = 0
        for i in range(n_samples):
            # skip points already claimed by a cluster and non-core points
            if self.labels_[i] != -1 or not is_core_np[i]:
                continue

            # breadth-less flood fill from this core point
            self.labels_[i] = cluster_id
            stack = [i]
            while stack:
                curr = stack.pop()
                # get neighbors using the boolean mask
                neighbors = np.where(adj_np[curr])[0]
                for neighbor in neighbors:
                    if self.labels_[neighbor] == -1:
                        self.labels_[neighbor] = cluster_id
                        # only core points continue the expansion
                        if is_core_np[neighbor]:
                            stack.append(neighbor)
            cluster_id += 1
        return self

    def _compute_distances(self, X: mx.array):
        """Vectorized pairwise distance calculation on MLX for the configured metric."""
        if self.metric == "tanimoto":
            return get_tanimoto(fps=X, chunk_size=self.chunk_size, matrix=True)
        elif self.metric == "euclidean":
            # optimized L2: sqrt(sum(x^2) + sum(y^2) - 2 * x.T * y)
            sq_norms = mx.sum(X ** 2, axis=1, keepdims=True)
            dist_sq = sq_norms + sq_norms.T - 2 * mx.matmul(X, X.T)
            # clamp: numerical error can make dist_sq slightly negative
            return mx.sqrt(mx.maximum(dist_sq, 0.0))
        elif self.metric == "manhattan" or (self.metric == "minkowski" and self.p == 1):
            # L1 Distance
            return mx.sum(mx.abs(X[:, None, :] - X[None, :, :]), axis=-1)
        elif self.metric == "cosine":
            # cosine Distance = 1 - (A·B / (||A||*||B||)); epsilon avoids 0-norm division
            norm = mx.sqrt(mx.sum(X ** 2, axis=1, keepdims=True))
            similarity = mx.matmul(X, X.T) / (norm * norm.T + 1e-7)
            return 1.0 - similarity
        else:
            raise ValueError(f"Metric '{self.metric}' is not supported in this MLX implementation.")
300
+
301
+
302
+ # class MLXSpectralClustering:
303
+ # def __init__(self, n_clusters=8, gamma=1.0, affinity='rbf', assign_labels='kmeans'):
304
+ # self.n_clusters = n_clusters
305
+ # self.gamma = gamma
306
+ # self.affinity = affinity
307
+ # self.assign_labels = assign_labels
308
+ # self.labels_ = None
309
+ #
310
+ # def fit_predict(self, X):
311
+ # N = X.shape[0]
312
+ #
313
+ # # 1. Compute Affinity Matrix (RBF Kernel)
314
+ # # Using the same logic as sklearn: exp(-gamma * ||x-y||^2)
315
+ # sq_norms = mx.sum(X ** 2, axis=1)
316
+ # dist_sq = sq_norms[:, None] + sq_norms[None, :] - 2 * mx.matmul(X, X.T)
317
+ # A = mx.exp(-self.gamma * dist_sq)
318
+ #
319
+ # # 2. Compute Degree Matrix and Laplacian
320
+ # # L = D - A (Unnormalized) or L = I - D^-1/2 A D^-1/2 (Normalized)
321
+ # D = mx.sum(A, axis=1)
322
+ # D_inv_sqrt = 1.0 / mx.sqrt(D)
323
+ # L_norm = mx.eye(N) - (D_inv_sqrt[:, None] * A * D_inv_sqrt[None, :])
324
+ #
325
+ # # 3. Eigen Decomposition
326
+ # # We need the eigenvectors corresponding to the smallest eigenvalues
327
+ # evals, evecs = mx.linalg.eigh(L_norm)
328
+ #
329
+ # # 4. Extract Top K Eigenvectors (Spectral Embedding)
330
+ # U = evecs[:, :self.n_clusters]
331
+ #
332
+ # # Normalize rows to unit length (important for stability)
333
+ # U = U / mx.linalg.norm(U, axis=1, keepdims=True)
334
+ #
335
+ # # 5. Final Step: Run your existing KMeans on the embedding
336
+ # from your_kmeans_file import KMeans # Use your existing class here
337
+ # km = KMeans(n_clusters=self.n_clusters)
338
+ # self.labels_ = km.fit(U).labels_
339
+ #
340
+ # return self.labels_
341
+ #
342
+ # class MLXGaussianMixture:
343
+ # def __init__(self, n_components=1, tol=1e-3, max_iter=100, reg_covar=1e-6):
344
+ # self.n_components = n_components
345
+ # self.tol = tol
346
+ # self.max_iter = max_iter
347
+ # self.reg_covar = reg_covar # Matches sklearn's stability constant
348
+ #
349
+ # self.weights_ = None
350
+ # self.means_ = None
351
+ # self.covariances_ = None
352
+ #
353
+ # def fit(self, X):
354
+ # N, D = X.shape
355
+ # # Initialize weights uniformly and means randomly from data
356
+ # self.weights_ = mx.full((self.n_components,), 1.0 / self.n_components)
357
+ # self.means_ = X[mx.random.randint(0, N, (self.n_components,))]
358
+ # self.covariances_ = mx.stack([mx.eye(D) for _ in range(self.n_components)])
359
+ #
360
+ # prev_log_likelihood = -float('inf')
361
+ #
362
+ # for i in range(self.max_iter):
363
+ # # --- E-Step: Compute Responsibilities ---
364
+ # resp = self._estimate_responsibilities(X)
365
+ #
366
+ # # --- M-Step: Update Parameters ---
367
+ # nk = mx.sum(resp, axis=0) # Total weight in each cluster
368
+ # self.weights_ = nk / N
369
+ # self.means_ = mx.matmul(resp.T, X) / nk[:, None]
370
+ #
371
+ # for k in range(self.n_components):
372
+ # diff = X - self.means_[k]
373
+ # weighted_diff = diff * mx.sqrt(resp[:, k:k + 1])
374
+ # # Add regularization to diagonal for numerical stability
375
+ # self.covariances_[k] = mx.matmul(weighted_diff.T, weighted_diff) / nk[k] + \
376
+ # mx.eye(D) * self.reg_covar
377
+ #
378
+ # # Convergence Check
379
+ # current_log_likelihood = self._compute_log_likelihood(X)
380
+ # if abs(current_log_likelihood - prev_log_likelihood) < self.tol:
381
+ # break
382
+ # prev_log_likelihood = current_log_likelihood
383
+ # mx.eval(self.means_, self.covariances_)
384
+ #
385
+ # return self
386
+ #
387
+ # def _estimate_responsibilities(self, X):
388
+ # # Calculates P(cluster | point) using log-sum-exp for stability
389
+ # weighted_log_probs = self._compute_log_prob(X) + mx.log(self.weights_)
390
+ # log_prob_norm = mx.logsumexp(weighted_log_probs, axis=1, keepdims=True)
391
+ # return mx.exp(weighted_log_probs - log_prob_norm)
392
+ #
393
+ # def _compute_log_prob(self, X):
394
+ # # Vectorized multivariate normal log-pdf
395
+ # N, D = X.shape
396
+ # probs = []
397
+ # for k in range(self.n_components):
398
+ # diff = X - self.means_[k]
399
+ # # MLX handles linalg.inv and det very efficiently on GPU
400
+ # prec = mx.linalg.inv(self.covariances_[k])
401
+ # log_det = mx.log(mx.linalg.det(self.covariances_[k]))
402
+ #
403
+ # # Log Mahalanobis distance
404
+ # dist = mx.sum(mx.matmul(diff, prec) * diff, axis=1)
405
+ # log_prob = -0.5 * (D * mx.log(2 * 3.14159) + log_det + dist)
406
+ # probs.append(log_prob)
407
+ # return mx.stack(probs, axis=1)
408
+
409
+
410
if __name__ == "__main__":
    # Run any doctest examples embedded in this module's docstrings.
    import doctest

    doctest.testmod()
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "mlx-mol-cluster"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = [{ name = "Tony Eight Lin" }]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ requires-python = ">=3.11"
9
+ dependencies = [
10
+ "mlx (>=0.31.0,<0.32.0)",
11
+ "numpy (>=2.4.2,<3.0.0)",
12
+ "pandas (>=3.0.1,<4.0.0)",
13
+ "seaborn (>=0.13.2,<0.14.0)",
14
+ "rdkit (>=2025.9.5,<2026.0.0)",
15
+ ]
16
+
17
+ [tool.poetry]
18
+ packages = [{ include = "mlx_cluster" }]
19
+
20
+
21
+ [build-system]
22
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
23
+ build-backend = "poetry.core.masonry.api"
24
+
25
+ [dependency-groups]
26
+ dev = [
27
+ "jupyter (>=1.1.1,<2.0.0)",
28
+ "ipywidgets (>=8.1.8,<9.0.0)",
29
+ "bblean (>=0.10.1,<0.11.0)",
30
+ "mols2grid (>=2.2.0,<3.0.0)",
31
+ "scikit-learn (>=1.8.0,<2.0.0)"
32
+ ]