wsi-toolbox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wsi_toolbox/exp.py ADDED
@@ -0,0 +1,466 @@
1
import os
import warnings

from glob import glob
from tqdm import tqdm
from pydantic import Field
from PIL import Image, ImageDraw
import cv2
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt, colors as mcolors
from matplotlib.colors import LinearSegmentedColormap, Normalize
import seaborn as sns  # fixed: was `import seahorse as sns` — `sns` is seaborn's canonical alias
import h5py
import umap
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import hdbscan
import torch
import timm

from .utils import BaseMLCLI, BaseMLArgs

# sklearn renamed force_all_finite -> ensure_all_finite; silence the transition warning.
warnings.filterwarnings('ignore', category=FutureWarning, message='.*force_all_finite.*')
26
+
27
+
28
+
29
+ # def is_white_patch(patch, white_threshold=200, white_ratio=0.7):
30
+ # gray_patch = np.mean(patch, axis=-1)
31
+ # white_pixels = np.sum(gray_patch > white_threshold)
32
+ # total_pixels = patch.shape[0] * patch.shape[1]
33
+ # return (white_pixels / total_pixels) > white_ratio
34
+
35
+
36
def is_white_patch_std_sat(patch, rgb_std_threshold=5.0, sat_threshold=10, white_ratio=0.7, verbose=False):
    """Decide whether an RGB patch is mostly white background.

    A pixel counts as "white" when its per-pixel RGB standard deviation is
    below `rgb_std_threshold` OR its HSV saturation is below `sat_threshold`.
    The patch is classified white when the fraction of such pixels exceeds
    `white_ratio`.

    Returns True for a (mostly) white patch, False otherwise.
    """
    low_std = np.std(patch, axis=2) < rgb_std_threshold
    low_sat = cv2.cvtColor(patch, cv2.COLOR_RGB2HSV)[:, :, 1] < sat_threshold
    total = patch.shape[0] * patch.shape[1]
    ratio = np.sum(low_std | low_sat) / total
    is_white = ratio > white_ratio
    if verbose:
        # Debug line: AND/OR combinations plus the individual criteria ratios.
        print('whi' if is_white else 'use',
              'and{:.3f} or{:.3f} std{:.3f} sat{:.4f}'.format(
                  np.sum(low_std & low_sat) / total,
                  np.sum(low_std | low_sat) / total,
                  np.sum(low_std) / total,
                  np.sum(low_sat) / total,
              ),
              )
    return is_white
54
+
55
class CLI(BaseMLCLI):
    # NOTE(review): presumably each `run_<name>` method becomes a subcommand
    # with the matching `<Name>Args` class as its argument schema — confirm
    # against BaseMLCLI in .utils.

    class CommonArgs(BaseMLArgs):
        # This includes `--seed` param
        # Torch device string used by GPU-backed commands (e.g. 'cuda', 'cpu').
        device: str = 'cuda'
        pass

    class ClusterArgs(CommonArgs):
        # Coloring target for the UMAP plot; short flag '-T'.
        target: str = Field('cluster', s='-T')
        # When set, skip plt.show(); the figure is still written to out/umap/.
        noshow: bool = False
64
+
65
+ def run_cluster(self, a):
66
+ with h5py.File('data/slide_features.h5', 'r') as f:
67
+ features = f['features'][:]
68
+ df = pd.DataFrame({
69
+ 'name': [int((v.decode('utf-8'))) for v in f['names'][:]],
70
+ 'filename': [v.decode('utf-8') for v in f['filenames'][:]],
71
+ 'order': f['orders'][:],
72
+ })
73
+
74
+ df_clinical = pd.read_excel('./data/clinical_data_cleaned.xlsx', index_col=0)
75
+ df = pd.merge(
76
+ df,
77
+ df_clinical,
78
+ left_on='name',
79
+ right_index=True,
80
+ how='left'
81
+ )
82
+
83
+ print('Loaded features', features.shape)
84
+ scaler = StandardScaler()
85
+ scaled_features = scaler.fit_transform(features)
86
+ # scaled_features = features
87
+
88
+ print('UMAP fitting...')
89
+ reducer = umap.UMAP(
90
+ n_neighbors=10,
91
+ min_dist=0.05,
92
+ n_components=2,
93
+ metric='cosine',
94
+ random_state=a.seed,
95
+ n_jobs=1,
96
+ )
97
+ embedding = reducer.fit_transform(scaled_features)
98
+ print('Loaded features', features.shape)
99
+
100
+ if a.target in [
101
+ 'HDBSCAN',
102
+ 'CD10 IHC', 'MUM1 IHC', 'HANS', 'BCL6 FISH', 'MYC FISH', 'BCL2 FISH',
103
+ 'ECOG PS', 'LDH', 'EN', 'Stage', 'IPI Score',
104
+ 'IPI Risk Group (4 Class)', 'RIPI Risk Group', 'Follow-up Status',
105
+ ]:
106
+ mode = 'categorical'
107
+ elif a.target in ['MYC IHC', 'BCL2 IHC', 'BCL6 IHC', 'Age', 'OS', 'PFS']:
108
+ mode = 'numeric'
109
+ else:
110
+ raise RuntimeError('invalid target', a.target)
111
+
112
+
113
+ plt.figure(figsize=(10, 8))
114
+ marker_size = 15
115
+
116
+ if mode == 'categorical':
117
+ if a.target == 'cluster':
118
+ eps = 0.2
119
+ m = hdbscan.HDBSCAN(
120
+ min_cluster_size=5,
121
+ min_samples=5,
122
+ cluster_selection_epsilon=eps,
123
+ metric='euclidean',
124
+ )
125
+ labels = m.fit_predict(embedding)
126
+ n_labels = len(set(labels)) - (1 if -1 in labels else 0)
127
+ else:
128
+ labels = df[a.target].fillna(-1)
129
+ n_labels = len(set(labels))
130
+ cmap = plt.cm.viridis
131
+
132
+ noise_mask = labels == -1
133
+ valid_labels = sorted(list(set(labels[~noise_mask])))
134
+ norm = plt.Normalize(min(valid_labels or [0]), max(valid_labels or [1]))
135
+ for label in valid_labels:
136
+ mask = labels == label
137
+ color = cmap(norm(label))
138
+ plt.scatter(
139
+ embedding[mask, 0], embedding[mask, 1], c=[color],
140
+ s=marker_size, label=f'{a.target} {label}'
141
+ )
142
+
143
+ if np.any(noise_mask):
144
+ plt.scatter(
145
+ embedding[noise_mask, 0], embedding[noise_mask, 1], c='gray',
146
+ s=marker_size, marker='x', label='Noise/NaN',
147
+ )
148
+
149
+ else:
150
+ values = df[a.target]
151
+ norm = Normalize(vmin=values.min(), vmax=values.max())
152
+ values = values.fillna(-1)
153
+ has_value = values > 0
154
+ cmap = plt.cm.viridis
155
+ scatter = plt.scatter(embedding[has_value, 0], embedding[has_value, 1], c=values[has_value],
156
+ s=marker_size, cmap=cmap, norm=norm, label=a.target,)
157
+ if np.any(has_value):
158
+ plt.scatter(embedding[~has_value, 0], embedding[~has_value, 1], c='gray',
159
+ s=marker_size, marker='x', label='NaN')
160
+ cbar = plt.colorbar(scatter)
161
+ cbar.set_label(a.target)
162
+
163
+ plt.title(f'UMAP + {a.target}')
164
+ plt.xlabel('UMAP Dimension 1')
165
+ plt.ylabel('UMAP Dimension 2')
166
+ # plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
167
+ plt.legend()
168
+ plt.tight_layout()
169
+ os.makedirs('out/umap', exist_ok=True)
170
+ name = a.target.replace(' ', '_')
171
+ plt.savefig(f'out/umap/umap_{name}.png')
172
+ if not a.noshow:
173
+ plt.show()
174
+
175
+ class GlobalClusterArgs(CommonArgs):
176
+ noshow: bool = False
177
+ n_samples: int = 100
178
+
179
+ def run_global_cluster(self, a):
180
+ result = []
181
+
182
+ with h5py.File('data/global_features.h5', 'r') as f:
183
+ global_features = f['global_features'][:]
184
+ lengths = f['lengths'][:]
185
+
186
+ selected_features = []
187
+ iii = []
188
+
189
+ cursor = 0
190
+ for l in lengths:
191
+ slice = global_features[cursor:cursor+l]
192
+ ii = np.random.choice(slice.shape[0], size=a.n_samples, replace=False)
193
+ iii.append(ii)
194
+ selected_features.append(slice[ii])
195
+ cursor += l
196
+ selected_features = np.concatenate(selected_features)
197
+
198
+ features = selected_features
199
+
200
+ print('Loaded features', features.dtype, features.shape)
201
+ scaler = StandardScaler()
202
+ scaled_features = scaler.fit_transform(features)
203
+
204
+ reducer = umap.UMAP(
205
+ n_neighbors=80,
206
+ min_dist=0.3,
207
+ n_components=2,
208
+ metric='cosine',
209
+ # random_state=a.seed
210
+ )
211
+ embedding = reducer.fit_transform(scaled_features)
212
+
213
+ plt.scatter(embedding[:, 0], embedding[:, 1], s=1)
214
+ plt.title(f'UMAP')
215
+ plt.xlabel('UMAP Dimension 1')
216
+ plt.ylabel('UMAP Dimension 2')
217
+ # plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
218
+ plt.tight_layout()
219
+ plt.show()
220
+
221
+
222
+ class ImageHistArgs(CommonArgs):
223
+ input_path: str = Field(..., l='--in', s='-i')
224
+
225
+
226
+ def run_image_hist(self, a):
227
+ img = cv2.imread(a.input_path)
228
+
229
+ # BGRからRGBとHSVに変換
230
+ rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
231
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
232
+ lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
233
+
234
+ print(is_white_patch_std_sat(rgb, verbose=True))
235
+
236
+ # 8つのサブプロットを作成 (4x3)
237
+ fig, axes = plt.subplots(3, 4, figsize=(20, 10))
238
+
239
+ # RGBヒストグラム
240
+ for i, color in enumerate(['r', 'g', 'b']):
241
+ # ヒストグラムを計算
242
+ hist = cv2.calcHist([rgb], [i], None, [256], [0, 256])
243
+ # プロット
244
+ axes[0, i].plot(hist, color=color)
245
+ axes[0, i].set_xlim([0, 256])
246
+ axes[0, i].set_xticks(range(0, 257, 10)) # 10刻みでメモリを設定
247
+ axes[0, i].set_title(f'{color.upper()} Histogram')
248
+ axes[0, i].grid(True)
249
+
250
+ # RGB平均値ヒストグラム(グレースケール)
251
+ kernel_size = 3
252
+ mean_rgb = cv2.blur(rgb, (kernel_size, kernel_size))
253
+
254
+ # 各ピクセルでRGBの平均を計算してグレースケールに変換
255
+ gray_from_rgb = np.mean(mean_rgb, axis=2).astype(np.uint8)
256
+
257
+ # グレースケール画像のヒストグラムを計算
258
+ gray_hist = cv2.calcHist([gray_from_rgb], [0], None, [256], [0, 256])
259
+
260
+ # ヒストグラムをプロット
261
+ axes[0, 3].plot(gray_hist, color='gray')
262
+ axes[0, 3].set_xlim([0, 256])
263
+ axes[0, 3].set_title('Grayscale (RGB Mean) Histogram')
264
+ axes[0, 3].grid(True)
265
+
266
+ # HSVヒストグラム
267
+ colors = ['r', 'g', 'b'] # プロット用の色(実際のHSVとは無関係)
268
+ titles = ['Hue', 'Saturation', 'Value']
269
+ ranges = [[0, 180], [0, 256], [0, 256]] # H: 0-179, S: 0-255, V: 0-255
270
+ for i in range(3):
271
+ # ヒストグラムを計算
272
+ hist = cv2.calcHist([hsv], [i], None, [ranges[i][1]], ranges[i])
273
+ # プロット
274
+ axes[1, i].plot(hist, color=colors[i])
275
+ axes[1, i].set_xlim(ranges[i])
276
+ axes[1, i].set_xticks(range(0, ranges[i][1] + 1, 10))
277
+ axes[1, i].set_title(f'{titles[i]} Histogram')
278
+ axes[1, i].grid(True)
279
+
280
+ # RGB標準偏差ヒストグラム
281
+ # 標準偏差を計算
282
+ mean_squared = cv2.blur(np.square(rgb.astype(np.float32)), (kernel_size, kernel_size))
283
+ squared_mean = np.square(mean_rgb.astype(np.float32))
284
+ std_rgb = np.sqrt(np.maximum(0, mean_squared - squared_mean)).astype(np.uint8)
285
+
286
+ # RGBチャンネルの標準偏差の平均を計算
287
+ std_gray = np.mean(std_rgb, axis=2).astype(np.uint8)
288
+
289
+ # 表示幅を調整
290
+ max_std_value = np.max(std_gray)
291
+ histogram_range = [0, 50]
292
+
293
+ # ヒストグラムを計算
294
+ std_hist = cv2.calcHist([std_gray], [0], None, [max_std_value+1], histogram_range)
295
+
296
+ # ヒストグラムをプロット
297
+ axes[1, 3].plot(std_hist, color='orange')
298
+ axes[1, 3].set_xlim(histogram_range)
299
+ axes[1, 3].set_title(f'RGB Std Histogram (Range: 0-{max_std_value})')
300
+ axes[1, 3].grid(True)
301
+
302
+
303
+ # LABヒストグラム (3段目)
304
+ lab_colors = ['k', 'g', 'b'] # プロット用の色(L:黒, a:緑, b:青)
305
+ lab_titles = ['Lightness', 'a (green-red)', 'b (blue-yellow)']
306
+ lab_ranges = [[0, 256], [0, 256], [0, 256]] # L: 0-255, a: 0-255, b: 0-255
307
+
308
+ for i in range(3):
309
+ # ヒストグラムを計算
310
+ hist = cv2.calcHist([lab], [i], None, [256], [0, 256])
311
+ # プロット
312
+ axes[2, i].plot(hist, color=lab_colors[i])
313
+ axes[2, i].set_xlim([0, 256])
314
+ axes[2, i].set_xticks(range(0, 257, 10))
315
+ axes[2, i].set_title(f'LAB {lab_titles[i]} Histogram')
316
+ axes[2, i].grid(True)
317
+
318
+ # LAB平均値ヒストグラム
319
+ mean_lab = cv2.blur(lab, (kernel_size, kernel_size))
320
+ # 各ピクセルでLABの平均を計算
321
+ lab_mean = np.mean(mean_lab, axis=2).astype(np.uint8)
322
+ # ヒストグラムを計算
323
+ lab_mean_hist = cv2.calcHist([lab_mean], [0], None, [256], [0, 256])
324
+ # ヒストグラムをプロット
325
+ axes[2, 3].plot(lab_mean_hist, color='purple')
326
+ axes[2, 3].set_xlim([0, 256])
327
+ axes[2, 3].set_title('LAB Mean Histogram')
328
+ axes[2, 3].grid(True)
329
+
330
+ plt.tight_layout()
331
+ plt.show()
332
+
333
+ class PcaDimArgs(CommonArgs):
334
+ input_path: str = Field(..., l='--in', s='-i')
335
+ models: list[str] = Field(['gigapath'], choices=['uni', 'gigapath'])
336
+
337
+ def run_pca_dim(self, a):
338
+ with h5py.File(a.input_path, 'r') as f:
339
+ patch_count = f['metadata/patch_count'][()]
340
+ feature_arrays = []
341
+ for model in a.models:
342
+ path = f'{model}/features'
343
+ if path in f:
344
+ feature_arrays.append(f[path][:])
345
+ else:
346
+ raise RuntimeError(f'"{path}" does not exist. Do `process-patches` first')
347
+ features = np.concatenate(feature_arrays, axis=1)
348
+
349
+ # Run PCA
350
+ pca = PCA().fit(features)
351
+ explained_variance = pca.explained_variance_ratio_
352
+
353
+ # Cumulative explained variance plot
354
+ plt.figure(figsize=(12, 8))
355
+ plt.subplot(2, 1, 1)
356
+ plt.plot(np.cumsum(explained_variance))
357
+ plt.xlabel('Number of Dimensions')
358
+ plt.ylabel('Cumulative Explained Variance')
359
+ plt.grid(True)
360
+ plt.axhline(y=0.9, color='r', linestyle='-', label='90%')
361
+ plt.axhline(y=0.95, color='g', linestyle='-', label='95%')
362
+ plt.legend()
363
+
364
+ # Calculate dimensions needed for 90% and 95% explained variance
365
+ dim_90 = np.argmax(np.cumsum(explained_variance) >= 0.9) + 1
366
+ dim_95 = np.argmax(np.cumsum(explained_variance) >= 0.95) + 1
367
+ plt.title(f"90% Explained Variance: {dim_90} dims, 95% Explained Variance: {dim_95} dims")
368
+
369
+ # Scree plot for Elbow method
370
+ plt.subplot(2, 1, 2)
371
+ plt.plot(explained_variance, 'o-')
372
+ plt.xlabel('Principal Component')
373
+ plt.ylabel('Explained Variance Ratio')
374
+ plt.grid(True)
375
+ plt.title('Scree Plot (Elbow Method)')
376
+
377
+ # Automatically detect elbow point
378
+ # Calculate first derivative
379
+ diffs = np.diff(explained_variance)
380
+ # Calculate second derivative
381
+ diffs2 = np.diff(diffs)
382
+
383
+ # Find index where second derivative is maximum (+2 to correct for dimension reduction from derivatives)
384
+ elbow_idx = np.argmax(np.abs(diffs2)) + 2
385
+
386
+ # Display elbow point on plot
387
+ plt.axvline(x=elbow_idx, color='r', linestyle='--')
388
+ plt.text(elbow_idx + 0.5, explained_variance[elbow_idx],
389
+ f'Elbow Point: {elbow_idx}', color='red')
390
+
391
+ print(f"Dimensions needed for 90% explained variance: {dim_90}")
392
+ print(f"Dimensions needed for 95% explained variance: {dim_95}")
393
+ print(f"Optimal dimensions estimated by Elbow method: {elbow_idx}")
394
+
395
+ plt.tight_layout()
396
+ plt.show()
397
+
398
+ print(elbow_idx, dim_90, dim_95)
399
+
400
+
401
+
402
+
403
+ def run_embs(self, a):
404
+ paths = [
405
+ './data/image_to_test/25-0856_tile1.png',
406
+ './data/image_to_test/25-0856_tile2.png',
407
+ './data/image_to_test/25-0856_tile3.png',
408
+ ]
409
+ images = [np.array(Image.open(f)) for f in paths]
410
+ # img = img.crop((0, 0, 256, 256))
411
+ # x = np.array(img)
412
+ x = np.stack(images)
413
+ print(x.shape)
414
+
415
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to('cuda')
416
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to('cuda')
417
+
418
+ x = (torch.from_numpy(x)/255).permute(0, 3, 1, 2)
419
+ x = x.to(a.device)
420
+ x = (x-mean)/std
421
+
422
+ print('x shape', x.shape)
423
+
424
+ model = create_model('uni').cuda()
425
+ t = model.forward_features(x)
426
+ print('done inference')
427
+ t = t.cpu().detach().numpy()
428
+
429
+ patch_embs, cls_token = t[:, :-1, ...], t[:, -1, ...]
430
+ print('patch_embs', patch_embs.shape, 'cls_token', cls_token.shape)
431
+
432
+ s = patch_embs.shape
433
+ patch_embs_to_pca = patch_embs.reshape(s[0]*s[1], s[-1])
434
+
435
+ print('PCA input', patch_embs_to_pca.shape)
436
+
437
+ pca = PCA(n_components=3)
438
+ values = pca.fit_transform(patch_embs_to_pca)
439
+
440
+ scaler = MinMaxScaler()
441
+ values = scaler.fit_transform(values)
442
+
443
+ imgs = values.reshape(3,16,16,3)
444
+
445
+ plt.figure(figsize=(8, 7))
446
+
447
+ plt.subplot(2, 3, 1)
448
+ plt.imshow(images[0])
449
+ plt.subplot(2, 3, 2)
450
+ plt.imshow(images[1])
451
+ plt.subplot(2, 3, 3)
452
+ plt.imshow(images[2])
453
+
454
+ plt.subplot(2, 3, 4)
455
+ plt.imshow(imgs[0])
456
+ plt.subplot(2, 3, 5)
457
+ plt.imshow(imgs[1])
458
+ plt.subplot(2, 3, 6)
459
+ plt.imshow(imgs[2])
460
+
461
+ plt.tight_layout()
462
+ plt.show()
463
+
464
if __name__ == '__main__':
    # Dispatch to the subcommand given on the command line.
    CLI().run()
wsi_toolbox/models.py ADDED
@@ -0,0 +1,38 @@
1
+ import os
2
+ import torch
3
+ import timm
4
+ from timm.layers import SwiGLUPacked
5
+
6
+
7
# Canonical model name -> human-readable label.
MODEL_LABELS = {
    'uni': 'UNI',
    'gigapath': 'Prov-Gigapath',
    'virchow2': 'Virchow2',
}
# Reverse lookup: label -> canonical name.
MODEL_NAMES_BY_LABEL = {label: name for name, label in MODEL_LABELS.items()}
MODEL_NAMES = list(MODEL_LABELS)


def get_model_label(model_name) -> str:
    """Return the display label for *model_name*, or the name itself if unknown."""
    return MODEL_LABELS.get(model_name, model_name)
18
+
19
def create_model(model_name):
    """Instantiate a pretrained backbone from the Hugging Face hub via timm.

    Args:
        model_name: one of 'uni', 'gigapath', 'virchow2' (see MODEL_LABELS).

    Returns:
        A timm model loaded with pretrained weights.

    Raises:
        ValueError: if model_name is not one of the supported names.

    NOTE(review): these hub repos are gated; authentication/download is
    handled by timm/huggingface_hub outside this function — confirm.
    """
    if model_name == 'uni':
        # dynamic_img_size lets the ViT accept non-default input resolutions;
        # init_values=1e-5 enables LayerScale — presumably matching the UNI
        # release's loading recipe; TODO confirm.
        return timm.create_model('hf-hub:MahmoodLab/uni',
                                 pretrained=True,
                                 dynamic_img_size=True,
                                 init_values=1e-5)

    if model_name == 'gigapath':
        return timm.create_model('hf_hub:prov-gigapath/prov-gigapath',
                                 pretrained=True,
                                 dynamic_img_size=True)

    if model_name == 'virchow2':
        # Virchow2 is built with packed-SwiGLU MLP blocks and SiLU activations.
        return timm.create_model('hf-hub:paige-ai/Virchow2',
                                 pretrained=True,
                                 mlp_layer=SwiGLUPacked,
                                 act_layer=torch.nn.SiLU)

    raise ValueError('Invalid model_name', model_name)
38
+
@@ -0,0 +1,153 @@
1
+ import sys
2
+ import warnings
3
+ import logging
4
+
5
+ from PIL import Image, ImageFont, ImageDraw
6
+ from PIL.Image import Image as ImageType
7
+ from sklearn.decomposition import PCA
8
+ import cv2
9
+ import numpy as np
10
+ from matplotlib import pyplot as plt, colors as mcolors
11
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
12
+
13
+ from .cli import BaseMLCLI, BaseMLArgs
14
+
15
+
16
+
17
def yes_no_prompt(question):
    """Ask *question* on stdout; an empty answer or anything starting with 'y' means yes."""
    answer = input(f"{question} [Y/n]: ").lower()
    return answer == "" or answer.startswith("y")
21
+
22
+
23
def get_platform_font():
    """Return the path of a default TrueType font for the current OS."""
    if sys.platform == 'win32':
        # Windows: MS Gothic
        return 'C:\\Windows\\Fonts\\msgothic.ttc'
    if sys.platform == 'darwin':
        # macOS
        return '/System/Library/Fonts/Supplemental/Arial.ttf'
    # Linux and everything else
    # font_path = '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf' # TODO: propagation
    return '/usr/share/fonts/TTF/DejaVuSans.ttf'
35
+
36
+
37
def create_frame(size, color, text, font):
    """Build a transparent RGBA overlay: a colored square outline with a text tag.

    Args:
        size: side length of the square frame in pixels.
        color: hex color string for the outline and the tag background.
        text: label drawn in the top-left tag.
        font: PIL ImageFont used to render and measure the text.

    Returns:
        A PIL RGBA Image of shape (size, size).
    """
    frame = Image.new('RGBA', (size, size), (0, 0, 0, 0))
    draw = ImageDraw.Draw(frame)
    draw.rectangle((0, 0, size, size), outline=color, width=4)
    # White text on dark backgrounds, black on bright ones (HSV value >= 0.9).
    text_color = 'white' if mcolors.rgb_to_hsv(mcolors.hex2color(color))[2] < 0.9 else 'black'
    # Size the filled tag from the text's bounding box.
    # (Removed unused w/h locals computed from the bbox.)
    bbox = np.array(draw.textbbox((0, 0), text, font=font))
    draw.rectangle((4, 4, bbox[2] + 4, bbox[3] + 4), fill=color)
    draw.text((1, 1), text, font=font, fill=text_color)
    return frame
47
+
48
+
49
def plot_umap(embeddings, clusters, title="UMAP + Clustering", figsize=(10, 8)):
    """Scatter a 2-D embedding colored by cluster id and return the Figure.

    Args:
        embeddings: (n, 2) array of 2-D coordinates.
        clusters: length-n array of integer cluster ids; -1 marks noise.
        title: figure title.
        figsize: matplotlib figure size.

    Returns:
        The matplotlib Figure (centroid of each non-noise cluster annotated
        with its id).
    """
    cluster_ids = sorted(set(clusters))

    fig, ax = plt.subplots(figsize=figsize)
    cmap = plt.get_cmap('tab20')

    # Bug fix: the body mixed explicit `ax` with global `plt.*` calls, which
    # draws on whatever figure is current — wrong when a caller has another
    # figure open. Everything now targets `ax`/`fig` explicitly.
    for cluster_id in cluster_ids:
        coords = embeddings[clusters == cluster_id]
        if cluster_id == -1:
            color, label, size = 'black', 'Noise', 12
        else:
            color, label, size = [cmap(cluster_id % 20)], f'Cluster {cluster_id}', 7
        ax.scatter(coords[:, 0], coords[:, 1], s=size, c=color, label=label)

    # Annotate each non-noise cluster at its centroid.
    for cluster_id in cluster_ids:
        if cluster_id < 0:
            continue
        cluster_points = embeddings[clusters == cluster_id]
        if len(cluster_points) < 1:
            continue
        centroid_x = np.mean(cluster_points[:, 0])
        centroid_y = np.mean(cluster_points[:, 1])
        ax.text(centroid_x, centroid_y, str(cluster_id),
                fontsize=12, fontweight='bold',
                ha='center', va='center',
                bbox=dict(facecolor='white', alpha=0.1, edgecolor='none'))

    ax.set_title(title)
    ax.set_xlabel('UMAP Dimension 1')
    ax.set_ylabel('UMAP Dimension 2')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()

    return fig
87
+
88
+
89
def hover_images_on_scatters(scatters, imagess, ax=None, offset=(150, 30)):
    """Attach a hover tooltip that shows the image behind each scatter point.

    Args:
        scatters: list of matplotlib PathCollections (one per scatter call).
        imagess: parallel list of per-scatter image lists; imagess[k][i] is
            the image for point i of scatters[k]. Items may be ndarrays,
            PIL Images, or file-path strings.
        ax: target Axes (defaults to the current Axes).
        offset: tooltip offset from the hovered point, in points.

    The tooltip is a single AnnotationBbox whose image data is swapped on
    each motion event; it is hidden when the cursor leaves all points.
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    def as_image(image_or_path):
        # Normalize the three accepted item types; paths are opened lazily.
        if isinstance(image_or_path, np.ndarray):
            return image_or_path
        if isinstance(image_or_path, ImageType):
            return image_or_path
        if isinstance(image_or_path, str):
            return Image.open(image_or_path)
        raise RuntimeError('Invalid param', image_or_path)

    # Seed the tooltip with the first image; its data is replaced on hover.
    imagebox = OffsetImage(as_image(imagess[0][0]), zoom=.5)
    imagebox.image.axes = ax
    annot = AnnotationBbox(
        imagebox,
        xy=(0, 0),
        # xybox=(256, 256),
        # xycoords='data',
        boxcoords='offset points',
        # boxcoords=('axes fraction', 'data'),
        pad=0.1,
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=-0.3'),
        zorder=100)
    annot.set_visible(False)
    ax.add_artist(annot)

    def hover(event):
        # Mutates the shared annot/imagebox; first matching scatter wins.
        vis = annot.get_visible()
        if event.inaxes != ax:
            return
        for n, (sc, ii) in enumerate(zip(scatters, imagess)):
            cont, index = sc.contains(event)
            if cont:
                i = index['ind'][0]
                pos = sc.get_offsets()[i]
                annot.xy = pos
                # Shift the tooltip box away from the hovered point.
                annot.xybox = pos + np.array(offset)
                image = as_image(ii[i])
                # text = unique_code[n]
                # annot.set_text(text)
                # annot.get_bbox_patch().set_facecolor(cmap(int(text)/10))
                imagebox.set_data(image)
                annot.set_visible(True)
                fig.canvas.draw_idle()
                return

        # Cursor is over no point: hide the tooltip if it was showing.
        if vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()
            return

    fig.canvas.mpl_connect('motion_notify_event', hover)
+
145
def is_in_streamlit_context():
    """Return True when running under a Streamlit script runner.

    Returns False when streamlit is not installed or no run context exists.
    Also quiets streamlit's logger and warnings as a side effect.
    """
    logging.getLogger("streamlit").setLevel(logging.ERROR)
    warnings.filterwarnings("ignore", module="streamlit.*")
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ImportError:
        return False
    return get_script_run_ctx() is not None