wsi-toolbox 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +119 -0
- wsi_toolbox/app.py +753 -0
- wsi_toolbox/cli.py +485 -0
- wsi_toolbox/commands/__init__.py +92 -0
- wsi_toolbox/commands/clustering.py +214 -0
- wsi_toolbox/commands/dzi_export.py +202 -0
- wsi_toolbox/commands/patch_embedding.py +199 -0
- wsi_toolbox/commands/preview.py +335 -0
- wsi_toolbox/commands/wsi.py +196 -0
- wsi_toolbox/exp.py +466 -0
- wsi_toolbox/models.py +38 -0
- wsi_toolbox/utils/__init__.py +153 -0
- wsi_toolbox/utils/analysis.py +127 -0
- wsi_toolbox/utils/cli.py +25 -0
- wsi_toolbox/utils/helpers.py +57 -0
- wsi_toolbox/utils/progress.py +206 -0
- wsi_toolbox/utils/seed.py +21 -0
- wsi_toolbox/utils/st.py +53 -0
- wsi_toolbox/watcher.py +261 -0
- wsi_toolbox/wsi_files.py +187 -0
- wsi_toolbox-0.1.0.dist-info/METADATA +269 -0
- wsi_toolbox-0.1.0.dist-info/RECORD +25 -0
- wsi_toolbox-0.1.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.1.0.dist-info/entry_points.txt +2 -0
- wsi_toolbox-0.1.0.dist-info/licenses/LICENSE +21 -0
wsi_toolbox/exp.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
from glob import glob
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
from PIL import Image, ImageDraw
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from matplotlib import pyplot as plt, colors as mcolors
|
|
12
|
+
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
|
13
|
+
import seahorse as sns
|
|
14
|
+
import h5py
|
|
15
|
+
import umap
|
|
16
|
+
from sklearn.preprocessing import StandardScaler
|
|
17
|
+
from sklearn.cluster import DBSCAN
|
|
18
|
+
from sklearn.decomposition import PCA
|
|
19
|
+
import hdbscan
|
|
20
|
+
import torch
|
|
21
|
+
import timm
|
|
22
|
+
|
|
23
|
+
from .utils import BaseMLCLI, BaseMLArgs
|
|
24
|
+
|
|
25
|
+
warnings.filterwarnings('ignore', category=FutureWarning, message='.*force_all_finite.*')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# def is_white_patch(patch, white_threshold=200, white_ratio=0.7):
|
|
30
|
+
# gray_patch = np.mean(patch, axis=-1)
|
|
31
|
+
# white_pixels = np.sum(gray_patch > white_threshold)
|
|
32
|
+
# total_pixels = patch.shape[0] * patch.shape[1]
|
|
33
|
+
# return (white_pixels / total_pixels) > white_ratio
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def is_white_patch_std_sat(patch, rgb_std_threshold=5.0, sat_threshold=10, white_ratio=0.7, verbose=False):
    """Decide whether an RGB patch is mostly white background.

    A pixel counts as "white" when its per-pixel RGB standard deviation is
    below ``rgb_std_threshold`` OR its HSV saturation is below
    ``sat_threshold``; the patch is white when the fraction of such pixels
    exceeds ``white_ratio``.

    Args:
        patch: HxWx3 uint8 RGB image array.
        rgb_std_threshold: per-pixel RGB channel std-dev cutoff.
        sat_threshold: HSV saturation cutoff.
        white_ratio: fraction of white pixels above which the patch is white.
        verbose: print the intermediate pixel fractions for tuning.

    Returns:
        True when the patch is judged to be background.
    """
    low_std_mask = np.std(patch, axis=2) < rgb_std_threshold
    hsv = cv2.cvtColor(patch, cv2.COLOR_RGB2HSV)
    low_sat_mask = hsv[:, :, 1] < sat_threshold
    combined_mask = low_std_mask | low_sat_mask
    pixel_count = patch.shape[0] * patch.shape[1]
    measured_ratio = np.sum(combined_mask) / pixel_count
    if verbose:
        verdict = 'whi' if measured_ratio > white_ratio else 'use'
        detail = 'and{:.3f} or{:.3f} std{:.3f} sat{:.4f}'.format(
            np.sum(low_std_mask & low_sat_mask) / pixel_count,
            np.sum(combined_mask) / pixel_count,
            np.sum(low_std_mask) / pixel_count,
            np.sum(low_sat_mask) / pixel_count,
        )
        print(verdict, detail)
    return measured_ratio > white_ratio
|
|
54
|
+
|
|
55
|
+
class CLI(BaseMLCLI):
    """Ad-hoc experiment commands: UMAP/cluster visualization of slide and
    patch features, image histogram inspection, and PCA dimensionality
    analysis."""

    class CommonArgs(BaseMLArgs):
        # BaseMLArgs already provides the `--seed` param.
        device: str = 'cuda'

    class ClusterArgs(CommonArgs):
        # Coloring target: 'cluster'/'HDBSCAN' runs clustering on the
        # embedding; otherwise a clinical-data column name.
        target: str = Field('cluster', s='-T')
        noshow: bool = False

    def run_cluster(self, a):
        """UMAP-embed slide-level features and color the scatter either by
        HDBSCAN cluster assignment or by a clinical variable."""
        with h5py.File('data/slide_features.h5', 'r') as f:
            features = f['features'][:]
            df = pd.DataFrame({
                'name': [int(v.decode('utf-8')) for v in f['names'][:]],
                'filename': [v.decode('utf-8') for v in f['filenames'][:]],
                'order': f['orders'][:],
            })

        # Attach clinical variables keyed by the numeric slide name.
        df_clinical = pd.read_excel('./data/clinical_data_cleaned.xlsx', index_col=0)
        df = pd.merge(
            df,
            df_clinical,
            left_on='name',
            right_index=True,
            how='left'
        )

        print('Loaded features', features.shape)
        scaler = StandardScaler()
        scaled_features = scaler.fit_transform(features)

        print('UMAP fitting...')
        reducer = umap.UMAP(
            n_neighbors=10,
            min_dist=0.05,
            n_components=2,
            metric='cosine',
            random_state=a.seed,
            n_jobs=1,
        )
        embedding = reducer.fit_transform(scaled_features)
        print('Loaded features', features.shape)

        # FIX: the default target 'cluster' was missing from this list, so
        # invoking the command without -T raised RuntimeError.
        if a.target in [
            'cluster', 'HDBSCAN',
            'CD10 IHC', 'MUM1 IHC', 'HANS', 'BCL6 FISH', 'MYC FISH', 'BCL2 FISH',
            'ECOG PS', 'LDH', 'EN', 'Stage', 'IPI Score',
            'IPI Risk Group (4 Class)', 'RIPI Risk Group', 'Follow-up Status',
        ]:
            mode = 'categorical'
        elif a.target in ['MYC IHC', 'BCL2 IHC', 'BCL6 IHC', 'Age', 'OS', 'PFS']:
            mode = 'numeric'
        else:
            raise RuntimeError('invalid target', a.target)

        plt.figure(figsize=(10, 8))
        marker_size = 15

        if mode == 'categorical':
            # FIX: 'HDBSCAN' was accepted above but fell through to the
            # clinical-column branch (KeyError); it now also triggers clustering.
            if a.target in ('cluster', 'HDBSCAN'):
                eps = 0.2
                m = hdbscan.HDBSCAN(
                    min_cluster_size=5,
                    min_samples=5,
                    cluster_selection_epsilon=eps,
                    metric='euclidean',
                )
                labels = m.fit_predict(embedding)
                n_labels = len(set(labels)) - (1 if -1 in labels else 0)
            else:
                # NaN clinical values become -1 and are drawn as noise.
                labels = df[a.target].fillna(-1)
                n_labels = len(set(labels))
            cmap = plt.cm.viridis

            noise_mask = labels == -1
            valid_labels = sorted(list(set(labels[~noise_mask])))
            norm = plt.Normalize(min(valid_labels or [0]), max(valid_labels or [1]))
            for label in valid_labels:
                mask = labels == label
                color = cmap(norm(label))
                plt.scatter(
                    embedding[mask, 0], embedding[mask, 1], c=[color],
                    s=marker_size, label=f'{a.target} {label}'
                )

            if np.any(noise_mask):
                plt.scatter(
                    embedding[noise_mask, 0], embedding[noise_mask, 1], c='gray',
                    s=marker_size, marker='x', label='Noise/NaN',
                )

        else:
            values = df[a.target]
            # Normalize over the real value range before NaNs are sentineled.
            norm = Normalize(vmin=values.min(), vmax=values.max())
            values = values.fillna(-1)
            has_value = values > 0
            cmap = plt.cm.viridis
            scatter = plt.scatter(embedding[has_value, 0], embedding[has_value, 1], c=values[has_value],
                                  s=marker_size, cmap=cmap, norm=norm, label=a.target,)
            # FIX: was `np.any(has_value)` — NaN markers were drawn whenever any
            # valid point existed, instead of when any missing point exists.
            if np.any(~has_value):
                plt.scatter(embedding[~has_value, 0], embedding[~has_value, 1], c='gray',
                            s=marker_size, marker='x', label='NaN')
            cbar = plt.colorbar(scatter)
            cbar.set_label(a.target)

        plt.title(f'UMAP + {a.target}')
        plt.xlabel('UMAP Dimension 1')
        plt.ylabel('UMAP Dimension 2')
        plt.legend()
        plt.tight_layout()
        os.makedirs('out/umap', exist_ok=True)
        name = a.target.replace(' ', '_')
        plt.savefig(f'out/umap/umap_{name}.png')
        if not a.noshow:
            plt.show()

    class GlobalClusterArgs(CommonArgs):
        noshow: bool = False
        # Patches sampled per slide (capped at the slide's patch count).
        n_samples: int = 100

    def run_global_cluster(self, a):
        """UMAP-embed a random per-slide subsample of patch-level features
        pooled across all slides."""
        with h5py.File('data/global_features.h5', 'r') as f:
            global_features = f['global_features'][:]
            lengths = f['lengths'][:]

        selected_features = []
        iii = []

        cursor = 0
        for l in lengths:
            chunk = global_features[cursor:cursor + l]
            # FIX: np.random.choice(replace=False) raises when a slide holds
            # fewer than n_samples patches — cap the sample size.
            count = min(a.n_samples, chunk.shape[0])
            ii = np.random.choice(chunk.shape[0], size=count, replace=False)
            iii.append(ii)
            selected_features.append(chunk[ii])
            cursor += l
        features = np.concatenate(selected_features)

        print('Loaded features', features.dtype, features.shape)
        scaler = StandardScaler()
        scaled_features = scaler.fit_transform(features)

        reducer = umap.UMAP(
            n_neighbors=80,
            min_dist=0.3,
            n_components=2,
            metric='cosine',
            # random_state=a.seed
        )
        embedding = reducer.fit_transform(scaled_features)

        plt.scatter(embedding[:, 0], embedding[:, 1], s=1)
        plt.title(f'UMAP')
        plt.xlabel('UMAP Dimension 1')
        plt.ylabel('UMAP Dimension 2')
        plt.tight_layout()
        # FIX: the declared `noshow` flag was previously ignored here.
        if not a.noshow:
            plt.show()

    class ImageHistArgs(CommonArgs):
        input_path: str = Field(..., l='--in', s='-i')

    def run_image_hist(self, a):
        """Plot RGB / HSV / LAB histograms (plus blurred-mean and local-std
        variants) of a single image, for tuning the white-patch detector."""
        img = cv2.imread(a.input_path)

        # Convert BGR to RGB, HSV and LAB.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

        print(is_white_patch_std_sat(rgb, verbose=True))

        # 3x4 grid of subplots.
        fig, axes = plt.subplots(3, 4, figsize=(20, 10))

        # RGB channel histograms.
        for i, color in enumerate(['r', 'g', 'b']):
            hist = cv2.calcHist([rgb], [i], None, [256], [0, 256])
            axes[0, i].plot(hist, color=color)
            axes[0, i].set_xlim([0, 256])
            axes[0, i].set_xticks(range(0, 257, 10))  # ticks every 10 levels
            axes[0, i].set_title(f'{color.upper()} Histogram')
            axes[0, i].grid(True)

        # Grayscale histogram of the locally box-blurred RGB mean.
        kernel_size = 3
        mean_rgb = cv2.blur(rgb, (kernel_size, kernel_size))
        gray_from_rgb = np.mean(mean_rgb, axis=2).astype(np.uint8)
        gray_hist = cv2.calcHist([gray_from_rgb], [0], None, [256], [0, 256])
        axes[0, 3].plot(gray_hist, color='gray')
        axes[0, 3].set_xlim([0, 256])
        axes[0, 3].set_title('Grayscale (RGB Mean) Histogram')
        axes[0, 3].grid(True)

        # HSV channel histograms.
        colors = ['r', 'g', 'b']  # plot colors only; unrelated to HSV channels
        titles = ['Hue', 'Saturation', 'Value']
        ranges = [[0, 180], [0, 256], [0, 256]]  # H: 0-179, S: 0-255, V: 0-255
        for i in range(3):
            hist = cv2.calcHist([hsv], [i], None, [ranges[i][1]], ranges[i])
            axes[1, i].plot(hist, color=colors[i])
            axes[1, i].set_xlim(ranges[i])
            axes[1, i].set_xticks(range(0, ranges[i][1] + 1, 10))
            axes[1, i].set_title(f'{titles[i]} Histogram')
            axes[1, i].grid(True)

        # Local RGB standard deviation via E[X^2] - E[X]^2 over the blur window.
        mean_squared = cv2.blur(np.square(rgb.astype(np.float32)), (kernel_size, kernel_size))
        squared_mean = np.square(mean_rgb.astype(np.float32))
        std_rgb = np.sqrt(np.maximum(0, mean_squared - squared_mean)).astype(np.uint8)
        std_gray = np.mean(std_rgb, axis=2).astype(np.uint8)

        # FIX: cast to int — a uint8 max of 255 would overflow on `+ 1`.
        max_std_value = int(np.max(std_gray))
        histogram_range = [0, 50]
        std_hist = cv2.calcHist([std_gray], [0], None, [max_std_value + 1], histogram_range)
        axes[1, 3].plot(std_hist, color='orange')
        axes[1, 3].set_xlim(histogram_range)
        axes[1, 3].set_title(f'RGB Std Histogram (Range: 0-{max_std_value})')
        axes[1, 3].grid(True)

        # LAB channel histograms (third row).
        lab_colors = ['k', 'g', 'b']  # plot colors: L black, a green, b blue
        lab_titles = ['Lightness', 'a (green-red)', 'b (blue-yellow)']
        for i in range(3):
            hist = cv2.calcHist([lab], [i], None, [256], [0, 256])
            axes[2, i].plot(hist, color=lab_colors[i])
            axes[2, i].set_xlim([0, 256])
            axes[2, i].set_xticks(range(0, 257, 10))
            axes[2, i].set_title(f'LAB {lab_titles[i]} Histogram')
            axes[2, i].grid(True)

        # Histogram of the blurred per-pixel LAB channel mean.
        mean_lab = cv2.blur(lab, (kernel_size, kernel_size))
        lab_mean = np.mean(mean_lab, axis=2).astype(np.uint8)
        lab_mean_hist = cv2.calcHist([lab_mean], [0], None, [256], [0, 256])
        axes[2, 3].plot(lab_mean_hist, color='purple')
        axes[2, 3].set_xlim([0, 256])
        axes[2, 3].set_title('LAB Mean Histogram')
        axes[2, 3].grid(True)

        plt.tight_layout()
        plt.show()

    class PcaDimArgs(CommonArgs):
        input_path: str = Field(..., l='--in', s='-i')
        models: list[str] = Field(['gigapath'], choices=['uni', 'gigapath'])

    def run_pca_dim(self, a):
        """Estimate how many PCA dimensions the patch features need, via
        90%/95% cumulative explained variance and the elbow method."""
        with h5py.File(a.input_path, 'r') as f:
            feature_arrays = []
            for model in a.models:
                path = f'{model}/features'
                if path in f:
                    feature_arrays.append(f[path][:])
                else:
                    raise RuntimeError(f'"{path}" does not exist. Do `process-patches` first')
            # Concatenate per-model features along the feature axis.
            features = np.concatenate(feature_arrays, axis=1)

        # Run PCA
        pca = PCA().fit(features)
        explained_variance = pca.explained_variance_ratio_

        # Cumulative explained variance plot
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 1, 1)
        plt.plot(np.cumsum(explained_variance))
        plt.xlabel('Number of Dimensions')
        plt.ylabel('Cumulative Explained Variance')
        plt.grid(True)
        plt.axhline(y=0.9, color='r', linestyle='-', label='90%')
        plt.axhline(y=0.95, color='g', linestyle='-', label='95%')
        plt.legend()

        # Calculate dimensions needed for 90% and 95% explained variance
        dim_90 = np.argmax(np.cumsum(explained_variance) >= 0.9) + 1
        dim_95 = np.argmax(np.cumsum(explained_variance) >= 0.95) + 1
        plt.title(f"90% Explained Variance: {dim_90} dims, 95% Explained Variance: {dim_95} dims")

        # Scree plot for Elbow method
        plt.subplot(2, 1, 2)
        plt.plot(explained_variance, 'o-')
        plt.xlabel('Principal Component')
        plt.ylabel('Explained Variance Ratio')
        plt.grid(True)
        plt.title('Scree Plot (Elbow Method)')

        # Automatically detect the elbow point: maximum curvature of the
        # explained-variance curve (+2 corrects for the two np.diff calls).
        diffs = np.diff(explained_variance)
        diffs2 = np.diff(diffs)
        elbow_idx = np.argmax(np.abs(diffs2)) + 2

        # Display elbow point on plot
        plt.axvline(x=elbow_idx, color='r', linestyle='--')
        plt.text(elbow_idx + 0.5, explained_variance[elbow_idx],
                 f'Elbow Point: {elbow_idx}', color='red')

        print(f"Dimensions needed for 90% explained variance: {dim_90}")
        print(f"Dimensions needed for 95% explained variance: {dim_95}")
        print(f"Optimal dimensions estimated by Elbow method: {elbow_idx}")

        plt.tight_layout()
        plt.show()

        print(elbow_idx, dim_90, dim_95)

    def run_embs(self, a):
        """Visualize patch-token embeddings of three fixed tiles as RGB via a
        3-component PCA, side by side with the source tiles."""
        # FIX: these names were referenced but never imported at module level
        # (NameError at runtime). Imported locally to keep module import light.
        from sklearn.preprocessing import MinMaxScaler
        from .models import create_model

        paths = [
            './data/image_to_test/25-0856_tile1.png',
            './data/image_to_test/25-0856_tile2.png',
            './data/image_to_test/25-0856_tile3.png',
        ]
        images = [np.array(Image.open(f)) for f in paths]
        x = np.stack(images)
        print(x.shape)

        # ImageNet normalization constants, broadcastable over NCHW.
        # FIX: was hard-coded to 'cuda' while `x` used a.device — now consistent.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(a.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(a.device)

        x = (torch.from_numpy(x) / 255).permute(0, 3, 1, 2)
        x = x.to(a.device)
        x = (x - mean) / std

        print('x shape', x.shape)

        model = create_model('uni').to(a.device)
        t = model.forward_features(x)
        print('done inference')
        t = t.cpu().detach().numpy()

        # Split the token sequence into patch tokens and the trailing token.
        patch_embs, cls_token = t[:, :-1, ...], t[:, -1, ...]
        print('patch_embs', patch_embs.shape, 'cls_token', cls_token.shape)

        s = patch_embs.shape
        patch_embs_to_pca = patch_embs.reshape(s[0] * s[1], s[-1])

        print('PCA input', patch_embs_to_pca.shape)

        # Project every patch token to 3 components, then min-max scale to
        # [0, 1] so the components can be rendered as RGB.
        pca = PCA(n_components=3)
        values = pca.fit_transform(patch_embs_to_pca)

        scaler = MinMaxScaler()
        values = scaler.fit_transform(values)

        # assumes a 16x16 patch-token grid per tile — TODO confirm for the model used
        imgs = values.reshape(3, 16, 16, 3)

        plt.figure(figsize=(8, 7))

        plt.subplot(2, 3, 1)
        plt.imshow(images[0])
        plt.subplot(2, 3, 2)
        plt.imshow(images[1])
        plt.subplot(2, 3, 3)
        plt.imshow(images[2])

        plt.subplot(2, 3, 4)
        plt.imshow(imgs[0])
        plt.subplot(2, 3, 5)
        plt.imshow(imgs[1])
        plt.subplot(2, 3, 6)
        plt.imshow(imgs[2])

        plt.tight_layout()
        plt.show()
|
|
463
|
+
|
|
464
|
+
if __name__ == '__main__':
    # Script entry point: dispatch to the CLI runner.
    CLI().run()
|
wsi_toolbox/models.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import timm
|
|
4
|
+
from timm.layers import SwiGLUPacked
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Human-readable display labels keyed by internal model name.
# NOTE: insertion order matters — MODEL_NAMES preserves it.
MODEL_LABELS = dict(
    uni='UNI',
    gigapath='Prov-Gigapath',
    virchow2='Virchow2',
)
# Reverse lookup: display label -> internal model name.
MODEL_NAMES_BY_LABEL = {label: name for name, label in MODEL_LABELS.items()}
# All supported internal model names, in declaration order.
MODEL_NAMES = [*MODEL_LABELS]


def get_model_label(model_name) -> str:
    """Return the display label for *model_name*; unknown names pass through unchanged."""
    try:
        return MODEL_LABELS[model_name]
    except KeyError:
        return model_name
|
|
18
|
+
|
|
19
|
+
def create_model(model_name):
    """Instantiate a pretrained WSI patch-encoder backbone from the HF hub.

    Supported names: 'uni', 'gigapath', 'virchow2' (see MODEL_LABELS).

    Raises:
        ValueError: when *model_name* is not one of the supported names.
    """
    # Hub id plus model-specific timm kwargs, keyed by internal name.
    specs = {
        'uni': ('hf-hub:MahmoodLab/uni',
                dict(dynamic_img_size=True, init_values=1e-5)),
        'gigapath': ('hf_hub:prov-gigapath/prov-gigapath',
                     dict(dynamic_img_size=True)),
        'virchow2': ('hf-hub:paige-ai/Virchow2',
                     dict(mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)),
    }
    if model_name not in specs:
        raise ValueError('Invalid model_name', model_name)
    hub_id, extra_kwargs = specs[model_name]
    return timm.create_model(hub_id, pretrained=True, **extra_kwargs)
|
|
38
|
+
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import warnings
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from PIL import Image, ImageFont, ImageDraw
|
|
6
|
+
from PIL.Image import Image as ImageType
|
|
7
|
+
from sklearn.decomposition import PCA
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
from matplotlib import pyplot as plt, colors as mcolors
|
|
11
|
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
|
12
|
+
|
|
13
|
+
from .cli import BaseMLCLI, BaseMLArgs
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def yes_no_prompt(question):
    """Print *question* with a ``[Y/n]`` suffix and read one line from stdin.

    An empty answer or anything starting with 'y' (case-insensitive) counts
    as yes.
    """
    print(f"{question} [Y/n]: ", end="")
    answer = input().lower()
    if answer == "":
        # Default on bare Enter is "yes".
        return True
    return answer.startswith("y")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_platform_font():
    """Return the path of a platform-default TrueType font file.

    NOTE(review): the paths are fixed guesses and are not checked for
    existence; on Linux only the Arch-style TTF location is tried.
    """
    if sys.platform == 'win32':
        # Windows: MS Gothic
        return 'C:\\Windows\\Fonts\\msgothic.ttc'
    if sys.platform == 'darwin':
        # macOS
        return '/System/Library/Fonts/Supplemental/Arial.ttf'
    # Linux
    # font_path = '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'  # TODO: propagation
    return '/usr/share/fonts/TTF/DejaVuSans.ttf'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_frame(size, color, text, font):
    """Build a transparent RGBA overlay tile: a colored border with a small
    filled label box containing *text* in the top-left corner.

    Args:
        size: side length of the square frame in pixels.
        color: border/label fill color (hex string, parsed by matplotlib).
        text: label text drawn inside the filled box.
        font: PIL ImageFont used for the label.

    Returns:
        A PIL RGBA Image of shape (size, size) with transparent interior.
    """
    frame = Image.new('RGBA', (size, size), (0, 0, 0, 0))
    draw = ImageDraw.Draw(frame)
    # Border in the given color, 4 px wide.
    draw.rectangle((0, 0, size, size), outline=color, width=4)
    # Pick white or black text depending on the HSV value (brightness) of the fill.
    text_color = 'white' if mcolors.rgb_to_hsv(mcolors.hex2color(color))[2]<0.9 else 'black'
    bbox = np.array(draw.textbbox((0, 0), text, font=font))
    w, h = bbox[2]-bbox[0], bbox[3]-bbox[1]  # NOTE(review): w and h are computed but unused
    # Filled background box sized to the text, inset past the border.
    draw.rectangle((4, 4, bbox[2]+4, bbox[3]+4), fill=color)
    draw.text((1, 1), text, font=font, fill=text_color)
    return frame
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def plot_umap(embeddings, clusters, title="UMAP + Clustering", figsize=(10, 8)):
    """Scatter-plot 2D *embeddings* colored by *clusters*, with the cluster id
    drawn at each cluster's centroid.

    Args:
        embeddings: (N, 2) array of 2D coordinates.
        clusters: length-N array of integer cluster ids; -1 marks noise.
        title: figure title.
        figsize: matplotlib figure size.

    Returns:
        The created matplotlib Figure.
    """
    cluster_ids = sorted(list(set(clusters)))

    fig, ax = plt.subplots(figsize=figsize)
    # tab20 cycles colors for up to 20 distinct clusters (modulo beyond that).
    cmap = plt.get_cmap('tab20')

    # One scatter call per cluster so each gets its own legend entry.
    for i, cluster_id in enumerate(cluster_ids):
        coords = embeddings[clusters == cluster_id]
        if cluster_id == -1:
            # Noise points: black, slightly larger markers.
            color = 'black'
            label = 'Noise'
            size = 12
        else:
            color = [cmap(cluster_id % 20)]
            label = f'Cluster {cluster_id}'
            size = 7
        plt.scatter(coords[:, 0], coords[:, 1], s=size, c=color, label=label)

    # Annotate each (non-noise, non-empty) cluster with its id at the centroid.
    for cluster_id in cluster_ids:
        if cluster_id < 0:
            continue
        cluster_points = embeddings[clusters == cluster_id]
        if len(cluster_points) < 1:
            continue
        centroid_x = np.mean(cluster_points[:, 0])
        centroid_y = np.mean(cluster_points[:, 1])
        ax.text(centroid_x, centroid_y, str(cluster_id),
                fontsize=12, fontweight='bold',
                ha='center', va='center',
                bbox=dict(facecolor='white', alpha=0.1, edgecolor='none'))

    plt.title(title)
    plt.xlabel('UMAP Dimension 1')
    plt.ylabel('UMAP Dimension 2')
    # Legend outside the axes to keep the scatter unobstructed.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    return fig
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def hover_images_on_scatters(scatters, imagess, ax=None, offset=(150, 30)):
    """Attach a hover handler that shows a thumbnail next to the scatter point
    under the mouse cursor.

    Args:
        scatters: list of matplotlib PathCollection objects (scatter plots).
        imagess: parallel list; imagess[n][i] is the image (ndarray, PIL Image,
            or file path) for point i of scatters[n].
        ax: axes to attach to; defaults to the current axes.
        offset: (x, y) offset in points of the thumbnail from the cursor.
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    def as_image(image_or_path):
        # Normalize the supported inputs to something OffsetImage can display.
        if isinstance(image_or_path, np.ndarray):
            return image_or_path
        if isinstance(image_or_path, ImageType):
            return image_or_path
        if isinstance(image_or_path, str):
            return Image.open(image_or_path)
        raise RuntimeError('Invalid param', image_or_path)

    # Single reusable annotation box; its image is swapped on each hover.
    imagebox = OffsetImage(as_image(imagess[0][0]), zoom=.5)
    imagebox.image.axes = ax
    annot = AnnotationBbox(
        imagebox,
        xy=(0, 0),
        # xybox=(256, 256),
        # xycoords='data',
        boxcoords='offset points',
        # boxcoords=('axes fraction', 'data'),
        pad=0.1,
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=-0.3'),
        zorder=100)
    annot.set_visible(False)
    ax.add_artist(annot)

    def hover(event):
        # Show the thumbnail of the first scatter point hit; hide it when the
        # cursor leaves all points.
        vis = annot.get_visible()
        if event.inaxes != ax:
            return
        for n, (sc, ii) in enumerate(zip(scatters, imagess)):
            cont, index = sc.contains(event)
            if cont:
                i = index['ind'][0]
                pos = sc.get_offsets()[i]
                annot.xy = pos
                annot.xybox = pos + np.array(offset)
                image = as_image(ii[i])
                # text = unique_code[n]
                # annot.set_text(text)
                # annot.get_bbox_patch().set_facecolor(cmap(int(text)/10))
                imagebox.set_data(image)
                annot.set_visible(True)
                fig.canvas.draw_idle()
                return

        if vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()
            return

    fig.canvas.mpl_connect('motion_notify_event', hover)
|
|
144
|
+
|
|
145
|
+
def is_in_streamlit_context():
    """Return True when executing inside a Streamlit script run.

    Returns False when streamlit is not installed or when called outside a
    Streamlit script runner. Also silences streamlit's logger/warnings as a
    side effect.
    """
    logging.getLogger("streamlit").setLevel(logging.ERROR)
    warnings.filterwarnings("ignore", module="streamlit.*")
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ImportError:
        # streamlit not installed -> definitely not in its context.
        return False
    return get_script_run_ctx() is not None
|