bblean 0.6.0b1__cp312-cp312-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bblean/plotting.py ADDED
@@ -0,0 +1,479 @@
1
+ r"""Plotting and visualization convenience functions"""
2
+
3
+ import warnings
4
+ from pathlib import Path
5
+ import pickle
6
+ import random
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib as mpl
12
+ import seaborn as sns
13
+ from rdkit import Chem
14
+ from rdkit.Chem import Draw
15
+ import colorcet
16
+ from sklearn.preprocessing import StandardScaler, normalize as normalize_features
17
+ from sklearn.decomposition import PCA
18
+ from openTSNE.sklearn import TSNE
19
+ from openTSNE.affinity import Multiscale
20
+ import umap
21
+
22
+ from bblean.utils import batched, _num_avail_cpus, _has_files_or_valid_symlinks
23
+ from bblean.analysis import ClusterAnalysis, cluster_analysis
24
+ from bblean._config import TSNE_SEED
25
+
26
+ __all__ = [
27
+ "summary_plot",
28
+ "tsne_plot",
29
+ "umap_plot",
30
+ "pops_plot",
31
+ "pca_plot",
32
+ "dump_mol_images",
33
+ ]
34
+
35
+
36
def pops_plot(
    c: ClusterAnalysis,
    /,
    title: str | None = None,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
    r"""Distribution of cluster populations using KDE

    Draws a filled kernel density estimate of ``c.sizes`` on a new figure.

    Args:
        c: Cluster analysis whose ``sizes``, ``clusters_num`` and ``min_size``
            attributes are read.
        title: Optional text appended to the figure title as ``"... for {title}"``.

    Returns:
        The figure and a 1-tuple holding its single axes.
    """
    fig, ax = plt.subplots()
    cluster_sizes = c.sizes
    sns.kdeplot(
        ax=ax,
        data=cluster_sizes,
        color="tab:purple",
        bw_adjust=0.25,
        gridsize=len(cluster_sizes) // 5,
        fill=True,
        warn_singular=False,
    )
    # The density goes on the y axis; the original code set the x label twice,
    # which left the y axis without its intended label.
    ax.set_ylabel("Density")
    ax.set_xlabel("Cluster size")
    msg = f"Populations for top {c.clusters_num} largest clusters"
    if c.min_size is not None:
        msg = f"{msg} (min. size = {c.min_size})"
    if title is not None:
        msg = f"{msg} for {title}"
    fig.suptitle(msg)
    return fig, (ax,)
62
+
63
+
64
+ # Similar to "init_plot" in the original bitbirch
65
def _annotate_bar_values(ax: plt.Axes, values: tp.Iterable[tp.Any], color: str) -> None:
    r"""Write each value just above its bar, at x positions 0, 1, 2, ..."""
    for i, value in enumerate(values):
        ax.text(
            i,
            value,
            f"{value}",
            ha="center",
            va="bottom",
            color=color,
            fontsize=5,
        )


def summary_plot(
    c: ClusterAnalysis,
    /,
    title: str | None = None,
    counts_ylim: int | None = None,
    annotate: bool = True,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
    r"""Create a summary plot from a cluster analysis

    Plots the number of molecules per cluster as bars. If the analysis contains
    scaffolds, the number of unique scaffolds per cluster is overlaid as a second
    set of bars. If the analysis contains fingerprints, the per-cluster iSIM
    values are drawn on a twin y axis.

    Args:
        c: Cluster analysis to plot.
        title: Optional text appended to the figure title as ``"... for {title}"``.
        counts_ylim: Upper limit of the counts axis (``None`` lets matplotlib
            choose it).
        annotate: If true, use a small high-dpi figure and write the value of
            every bar above it.

    Returns:
        The figure and a tuple of axes: ``(ax,)`` without fingerprints,
        ``(ax, ax_isim)`` with fingerprints.
    """
    orange = "tab:orange"
    blue = "tab:blue"
    # NOTE: this modifies the global matplotlib rcParams (only when the font
    # size is still at matplotlib's default), so it also affects later figures.
    if mpl.rcParamsDefault["font.size"] == plt.rcParams["font.size"]:
        plt.rcParams["font.size"] = 8
    if annotate:
        fig, ax = plt.subplots(figsize=(5, 2.5), dpi=250, constrained_layout=True)
    else:
        fig, ax = plt.subplots()

    # All drawing goes through ``ax`` explicitly instead of the implicit pyplot
    # "current axes", so the result does not depend on global pyplot state.
    label_strs = c.labels.astype(str)  # TODO: Is this necessary?
    ax.bar(
        label_strs,
        c.sizes,
        color=blue,
        label="Num. molecules",
        zorder=0,
    )
    ax.set_ylim(0, counts_ylim)
    if annotate:
        _annotate_bar_values(ax, c.sizes, "black")

    if c.has_scaffolds:
        # Unique scaffolds are drawn on top of the molecule counts
        ax.bar(
            label_strs,
            c.unique_scaffolds_num,
            color=orange,
            label="Num. unique scaffolds",
            zorder=1,
        )
        if annotate:
            _annotate_bar_values(ax, c.unique_scaffolds_num, "white")

    # Labels
    ax.set_xlabel("Cluster label")
    ax.set_ylabel("Num. molecules")
    ax.set_xticks(range(c.clusters_num))

    axes: tuple[plt.Axes, ...] = (ax,)
    # Plot iSIM on a secondary y axis
    if c.has_fps:
        green = "tab:green"
        ax_isim = ax.twinx()
        # Bars sit at positions 0..N-1, so labels are shifted by one here
        # (assumes labels start at 1 -- the other plots in this module make
        # the same assumption)
        ax_isim.plot(
            c.labels - 1,
            c.isims,
            color=green,
            linestyle="dashed",
            linewidth=1.5,
            zorder=5,
            alpha=0.6,
        )
        ax_isim.scatter(
            c.labels - 1,
            c.isims,
            color=green,
            marker="o",
            s=15,
            label="Tanimoto iSIM",
            edgecolor="darkgreen",
            zorder=100,
            alpha=0.6,
        )
        ax_isim.set_ylabel("Tanimoto iSIM (average similarity)")
        ax_isim.set_yticks(np.arange(0, 1.1, 0.1))
        ax_isim.set_ylim(0, 1)
        ax_isim.spines["right"].set_color(green)
        ax_isim.tick_params(colors=green)
        ax_isim.yaxis.label.set_color(green)
        axes = (ax, ax_isim)

    # Anchor the figure-level legend near the upper right corner of the axes
    bbox = ax.get_position()
    fig.legend(
        loc="upper right",
        bbox_to_anchor=(bbox.x0 + 0.95 * bbox.width, bbox.y0 + 0.95 * bbox.height),
    )
    if c.has_all_clusters:
        msg = "Metrics of all clusters"
    else:
        msg = f"Metrics of top {c.clusters_num} largest clusters"
    if title is not None:
        msg = f"{msg} for {title}"
    fig.suptitle(msg)
    return fig, axes
176
+
177
+
178
+ def umap_plot(
179
+ c: ClusterAnalysis,
180
+ /,
181
+ title: str | None = None,
182
+ scaling: str = "normalize",
183
+ n_neighbors: int = 15,
184
+ min_dist: float = 0.1,
185
+ metric: str = "euclidean",
186
+ densmap: bool = False,
187
+ workers: int | None = None,
188
+ deterministic: bool = False,
189
+ ) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
190
+ r"""Create a UMAP plot from a cluster analysis"""
191
+ color_labels: list[int] = []
192
+ for num, label in zip(c.sizes, c.labels):
193
+ color_labels.extend([label - 1] * num) # color labels start with 0
194
+ num_top = c.clusters_num
195
+ if workers is None:
196
+ workers = _num_avail_cpus()
197
+
198
+ # I don't think these should be transformed, like this, only normalized
199
+ if scaling == "normalize":
200
+ fps_scaled = normalize_features(c.top_unpacked_fps)
201
+ elif scaling == "std":
202
+ scaler = StandardScaler()
203
+ fps_scaled = scaler.fit_transform(c.top_unpacked_fps)
204
+ elif scaling == "none":
205
+ fps_scaled = c.top_unpacked_fps
206
+ else:
207
+ raise ValueError(f"Unknown scaling {scaling}")
208
+ fps_umap = umap.UMAP(
209
+ densmap=densmap,
210
+ random_state=42 if deterministic else None,
211
+ n_components=2,
212
+ n_jobs=workers,
213
+ n_neighbors=n_neighbors,
214
+ min_dist=min_dist,
215
+ metric=metric,
216
+ ).fit_transform(fps_scaled)
217
+ fig, ax = plt.subplots(dpi=250, figsize=(4, 3.5))
218
+ scatter = ax.scatter(
219
+ fps_umap[:, 0],
220
+ fps_umap[:, 1],
221
+ c=color_labels,
222
+ cmap=mpl.colors.ListedColormap(colorcet.glasbey_bw_minc_20[:num_top]),
223
+ edgecolors="none",
224
+ alpha=0.5,
225
+ s=2,
226
+ )
227
+ # t-SNE plots *must be square*
228
+ ax.set_aspect("equal", adjustable="box")
229
+ cbar = plt.colorbar(scatter, label="Cluster label")
230
+ cbar.set_ticks(list(range(num_top)))
231
+ cbar.set_ticklabels(list(map(str, range(1, num_top + 1))))
232
+ ax.set_xlabel("UMAP component 1")
233
+ ax.set_ylabel("UMAP component 2")
234
+ if c.has_all_clusters:
235
+ msg = "UMAP of all clusters"
236
+ else:
237
+ msg = f"UMAP of top {num_top} largest clusters"
238
+ if title is not None:
239
+ msg = f"{msg} for {title}"
240
+ fig.suptitle(msg)
241
+ return fig, (ax,)
242
+
243
+
244
def pca_plot(
    c: ClusterAnalysis,
    /,
    title: str | None = None,
    scaling: str = "normalize",
    whiten: bool = False,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
    r"""Create a PCA plot from a cluster analysis

    The fingerprints in ``c.top_unpacked_fps`` are optionally scaled, projected
    onto their first two principal components, and drawn as a scatter plot
    colored by cluster.

    Args:
        c: Cluster analysis to plot.
        title: Optional text appended to the figure title as ``"... for {title}"``.
        scaling: One of ``"normalize"``, ``"std"`` or ``"none"``.
        whiten: Forwarded to ``sklearn.decomposition.PCA``.

    Returns:
        The figure and a 1-tuple holding the scatter axes.

    Raises:
        ValueError: If ``scaling`` is not one of the accepted values.
    """
    color_labels: list[int] = []
    for num, label in zip(c.sizes, c.labels):
        color_labels.extend([label - 1] * num)  # color labels start with 0
    num_top = c.clusters_num

    # I don't think these should be transformed, like this, only normalized
    if scaling == "normalize":
        fps_scaled = normalize_features(c.top_unpacked_fps)
    elif scaling == "std":
        fps_scaled = StandardScaler().fit_transform(c.top_unpacked_fps)
    elif scaling == "none":
        fps_scaled = c.top_unpacked_fps
    else:
        raise ValueError(f"Unknown scaling {scaling}")

    # The random state is fixed, so repeated calls use the same seed
    pca = PCA(n_components=2, whiten=whiten, random_state=1234)
    fps_pca = pca.fit_transform(fps_scaled)

    fig, ax = plt.subplots(dpi=250, figsize=(4, 3.5))
    scatter = ax.scatter(
        fps_pca[:, 0],
        fps_pca[:, 1],
        c=color_labels,
        cmap=mpl.colors.ListedColormap(colorcet.glasbey_bw_minc_20[:num_top]),
        edgecolors="none",
        alpha=0.5,
        s=2,
    )
    # Same unit length on both axes, so the projection is not distorted
    ax.set_aspect("equal", adjustable="box")
    cbar = plt.colorbar(scatter, label="Cluster label")
    # Ticks sit on the zero-based color indices but show the one-based labels
    cbar.set_ticks(list(range(num_top)))
    cbar.set_ticklabels(list(map(str, range(1, num_top + 1))))
    ax.set_xlabel("PCA component 1")
    ax.set_ylabel("PCA component 2")
    if c.has_all_clusters:
        msg = "PCA of all clusters"
    else:
        msg = f"PCA of top {num_top} largest clusters"
    if title is not None:
        msg = f"{msg} for {title}"
    fig.suptitle(msg)
    return fig, (ax,)
295
+
296
+
297
def tsne_plot(
    c: ClusterAnalysis,
    /,
    title: str | None = None,
    seed: int | None = TSNE_SEED,
    perplexity: int = 30,
    workers: int | None = None,
    scaling: str = "normalize",
    exaggeration: float | None = None,
    do_pca_init: bool = True,
    multiscale: bool = False,
    pca_reduce: int | None = None,
    metric: str = "euclidean",
    dof: float = 1.0,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
    r"""Create a t-SNE plot from a cluster analysis

    The fingerprints in ``c.top_unpacked_fps`` are optionally scaled, optionally
    reduced with PCA, embedded in two dimensions with openTSNE, and drawn as a
    scatter plot colored by cluster.

    Args:
        c: Cluster analysis to plot.
        title: Optional text appended to the figure title as ``"... for {title}"``.
        seed: Random state for t-SNE, for the multiscale affinities, and for the
            optional PCA pre-reduction.
        perplexity: t-SNE perplexity. With ``multiscale`` it is combined with a
            second perplexity of ``len(fps) / 100``.
        workers: Value passed as ``n_jobs``; defaults to ``_num_avail_cpus()``.
        scaling: One of ``"normalize"``, ``"std"`` or ``"none"``.
        exaggeration: Second-phase exaggeration forwarded to ``TSNE``.
        do_pca_init: Use ``"pca"`` initialization if true, ``"random"`` otherwise.
        multiscale: Fit with ``Multiscale`` affinities instead of the defaults.
        pca_reduce: If given, number of PCA components the features are reduced
            to before running t-SNE.
        metric: Accepted for interface compatibility; it is not used by this
            function body.
        dof: Forwarded to ``TSNE``.

    Returns:
        The figure and a 1-tuple holding the scatter axes.

    Raises:
        ValueError: If ``scaling`` is not one of the accepted values.
    """
    if workers is None:
        workers = _num_avail_cpus()
    color_labels: list[int] = []
    for num, label in zip(c.sizes, c.labels):
        color_labels.extend([label - 1] * num)  # color labels start with 0
    num_top = c.clusters_num

    # I don't think these should be transformed, like this, only normalized
    if scaling == "normalize":
        fps_scaled = normalize_features(c.top_unpacked_fps)
    elif scaling == "std":
        fps_scaled = StandardScaler().fit_transform(c.top_unpacked_fps)
    elif scaling == "none":
        fps_scaled = c.top_unpacked_fps
    else:
        raise ValueError(f"Unknown scaling {scaling}")
    if pca_reduce is not None:
        # Seed the reduction too: sklearn may choose a randomized solver, which
        # would otherwise make the output vary between runs despite ``seed``
        fps_scaled = PCA(n_components=pca_reduce, random_state=seed).fit_transform(
            fps_scaled
        )

    # Learning rate is set to N / exaggeration (good default)
    # Early exaggeration defaults to max(12, exaggeration) (good default)
    # exaggeration_iter = 250, normal_iter = 500 (good defaults)
    # "pca" is the method used by Dmitry Kobak et al. (good default), with some
    # jitter added for extra numerical stability
    # Multiscale may help with medium-sized datasets together with downsampling, but
    # it doesn't do much in my tests.
    # NOTE: Dimensionality reduction with PCA to ~50 features seems to mostly preserve
    # cluster structure
    initialization = "pca" if do_pca_init else "random"
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=seed,
        n_jobs=workers,
        dof=dof,
        exaggeration=exaggeration,  # second-phase exaggeration
        negative_gradient_method="fft",  # faster for large datasets
        initialization=initialization,
    )
    if multiscale:
        # Call the parent class ``fit`` directly so that custom affinities can
        # be supplied, then view the returned embedding as a plain ndarray
        fps_tsne = (
            super(TSNE, tsne)
            .fit(
                fps_scaled,
                affinities=Multiscale(
                    n_jobs=workers,
                    random_state=seed,
                    data=fps_scaled,
                    perplexities=[perplexity, len(fps_scaled) / 100],
                ),
                initialization=initialization,
            )
            .view(np.ndarray)
        )
    else:
        fps_tsne = tsne.fit_transform(fps_scaled)

    fig, ax = plt.subplots(dpi=250, figsize=(4, 3.5))
    scatter = ax.scatter(
        fps_tsne[:, 0],
        fps_tsne[:, 1],
        c=color_labels,
        cmap=mpl.colors.ListedColormap(colorcet.glasbey_bw_minc_20[:num_top]),
        edgecolors="none",
        alpha=0.5,
        s=2,
    )
    # t-SNE plots *must be square*
    ax.set_aspect("equal", adjustable="box")
    cbar = plt.colorbar(scatter, label="Cluster label")
    # Ticks sit on the zero-based color indices but show the one-based labels
    cbar.set_ticks(list(range(num_top)))
    cbar.set_ticklabels(list(map(str, range(1, num_top + 1))))
    ax.set_xlabel("t-SNE component 1")
    ax.set_ylabel("t-SNE component 2")
    if c.has_all_clusters:
        msg = "t-SNE of all clusters"
    else:
        msg = f"t-SNE of top {num_top} largest clusters"
    if title is not None:
        msg = f"{msg} for {title}"
    fig.suptitle(msg)
    return fig, (ax,)
395
+
396
+
397
def dump_mol_images(
    smiles: tp.Iterable[str],
    clusters: list[list[int]],
    cluster_idx: int = 0,
    batch_size: int = 30,
) -> None:
    r"""Dump smiles associated with a specific cluster as ``*.png`` image files

    The molecules of ``clusters[cluster_idx]`` are drawn in grids of at most
    ``batch_size`` molecules (5 per row). Grid number ``i`` is written to
    ``cluster_{cluster_idx}_{i}.png``, a path relative to the current working
    directory.

    Args:
        smiles: SMILES strings, indexed by the integers stored in ``clusters``.
            A single string is treated as a one-element sequence.
        clusters: For each cluster, the indices of its molecules.
        cluster_idx: Index of the cluster to dump.
        batch_size: Maximum number of molecules per image.

    Raises:
        ValueError: If a SMILES string can't be parsed by RDKit.
    """
    if isinstance(smiles, str):
        smiles = [smiles]
    if not isinstance(smiles, np.ndarray):
        # Materialize first: np.asarray() on a generator (or other non-sequence
        # iterable) gives a 0-d object array, which can't be indexed below
        smiles = np.asarray(list(smiles))
    idxs = clusters[cluster_idx]
    for i, idx_seq in enumerate(batched(idxs, batch_size)):
        mols = []
        for smi in smiles[list(idx_seq)]:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                raise ValueError(f"Could not parse smiles {smi}")
            mols.append(mol)
        img = Draw.MolsToGridImage(mols, molsPerRow=5)
        out_path = f"cluster_{cluster_idx}_{i}.png"
        # In notebooks RDKit hands back an IPython image that carries the PNG
        # bytes in ``.data``; otherwise it returns a PIL image, which has no
        # ``.data`` and has to be saved through its own API.
        # NOTE(review): confirm against the RDKit version in use.
        png_data = getattr(img, "data", None)
        if png_data is not None:
            with open(out_path, "wb") as f:
                f.write(png_data)
        else:
            img.save(out_path, format="PNG")
418
+
419
+
420
+ # For internal use, dispatches a visualization workflow and optionally saves
421
+ # plot to disk and/or displays it using mpl
422
def _dispatch_visualization(
    clusters_path: Path,
    fn_name: str,
    fn: tp.Callable[..., tp.Any],
    fn_kwargs: tp.Any,
    min_size: int = 0,
    smiles: tp.Iterable[str] = (),
    top: int | None = None,
    n_features: int | None = None,
    input_is_packed: bool = True,
    fps_path: Path | None = None,
    title: str | None = None,
    filename: str | None = None,
    verbose: bool = True,
    save: bool = True,
    show: bool = True,
) -> None:
    r"""Run a plotting function on pickled clusters, then save and/or show it

    Loads the clusters from ``clusters_path`` (or from ``clusters.pkl`` inside
    it, if it is a directory), locates the fingerprint files, builds a cluster
    analysis and calls ``fn(analysis, title=title, **fn_kwargs)``.

    If ``fps_path`` is not given, an ``input-fps`` directory next to the
    clusters file is used when it exists and is not empty. Otherwise only the
    ``"summary"`` workflow proceeds (with a warning); any other ``fn_name``
    raises ``RuntimeError``.

    With ``save`` the current pyplot figure is written to ``filename`` inside
    the current working directory; a missing ``filename`` defaults to
    ``"{fn_name}-{8 random hex digits}.pdf"``. With ``show`` the figure is
    displayed through ``plt.show()``. ``verbose`` is not used by this body.
    """
    pkl_path = clusters_path / "clusters.pkl" if clusters_path.is_dir() else clusters_path
    # NOTE: pickle.load can execute arbitrary code; only use trusted cluster files
    with open(pkl_path, mode="rb") as fh:
        clusters = pickle.load(fh)

    if fps_path is None:
        # Fall back to the fingerprints directory next to the clusters file
        candidate = pkl_path.parent / "input-fps"
        if candidate.is_dir() and _has_files_or_valid_symlinks(candidate):
            fps_path = candidate
        elif fn_name != "summary":
            raise RuntimeError(
                "Could not find input fingerprints. Please use --fps-path"
            )
        else:
            warnings.warn(
                "Could not find input fingerprints. Please use --fps-path."
                " Summary plot without fingerprints doesn't include isim values"
            )

    # A directory contributes all of its *.npy files (sorted), a file only itself
    fps_paths: list[Path] | None
    if fps_path is None:
        fps_paths = None
    elif fps_path.is_dir():
        fps_paths = sorted(fps_path.glob("*.npy"))
    else:
        fps_paths = [fps_path]

    analysis = cluster_analysis(
        clusters,
        fps_paths,
        smiles=smiles,
        top=top,
        n_features=n_features,
        input_is_packed=input_is_packed,
        min_size=min_size,
    )
    fn(analysis, title=title, **fn_kwargs)

    if save:
        if filename is None:
            # Random suffix, to avoid overwriting earlier plots
            filename = f"{fn_name}-{random.getrandbits(32):08x}.pdf"
        plt.savefig(Path.cwd() / filename)
    if show:
        plt.show()