wsi-toolbox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wsi_toolbox/cli.py ADDED
@@ -0,0 +1,485 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from glob import glob
5
+ from pathlib import Path as P
6
+
7
+ from tqdm import tqdm
8
+ from pydantic import Field
9
+ from pydantic_autocli import param
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import cv2
12
+ import numpy as np
13
+ import pandas as pd
14
+ from matplotlib import pyplot as plt, colors as mcolors
15
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox, TextArea, VPacker
16
+ import seaborn as sns
17
+ import h5py
18
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler
19
+ from sklearn.neighbors import NearestNeighbors
20
+ from sklearn.cluster import DBSCAN
21
+ from sklearn.decomposition import PCA
22
+ import networkx as nx
23
+ import leidenalg as la
24
+ import igraph as ig
25
+ import hdbscan
26
+ import torch
27
+ from torchvision import transforms
28
+ from torch.amp import autocast
29
+ import timm
30
+
31
+ from . import commands
32
+ from .models import create_model
33
+ from .utils import plot_umap
34
+ from .utils.cli import BaseMLCLI, BaseMLArgs
35
+ from .utils.analysis import leiden_cluster
36
+ from .utils.progress import tqdm_or_st
37
+
38
+
39
+ warnings.filterwarnings('ignore', category=FutureWarning, message='.*force_all_finite.*')  # hide FutureWarnings whose text matches this regex
40
+ warnings.filterwarnings('ignore', category=FutureWarning, message="You are using `torch.load` with `weights_only=False`")  # same for the torch.load notice; `message` is matched as a regex against the start of the warning text
41
+
42
+ DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'uni')  # model name taken from the DEFAULT_MODEL environment variable, falling back to 'uni'
43
+
44
+ commands.set_default_progress('tqdm')  # import-time side effect: the command layer's global progress backend becomes 'tqdm'
45
+ commands.set_default_model(DEFAULT_MODEL)  # import-time side effect: the command layer's global model name becomes DEFAULT_MODEL
46
+
47
def sigmoid(x):
    """Apply the logistic function 1 / (1 + e^-x) to ``x`` (scalar or NumPy array)."""
    decay = np.exp(-x)
    return 1 / (1 + decay)
49
+
50
+ class CLI(BaseMLCLI):
51
class CommonArgs(BaseMLArgs):
    # Options shared by every sub-command.
    # This includes `--seed` param (per the original author's note on BaseMLArgs).

    # Device string used for `.to(...)` calls and as the command layer's default device.
    device: str = 'cuda'
55
+
56
+ class Wsi2h5Args(CommonArgs):  # options for the wsi2h5 sub-command (consumed by run_wsi2h5)
57
+ input_path: str = param(..., l='--in', s='-i')  # required; path of the slide handed to Wsi2HDF5Command
58
+ output_path: str = param('', l='--out', s='-o')  # empty -> run_wsi2h5 uses the input path with its extension replaced by '.h5'
59
+ patch_size: int = param(256, s='-S')  # forwarded as patch_size= to Wsi2HDF5Command
60
+ overwrite: bool = param(False, s='-O')  # when False, an existing output file makes run_wsi2h5 print a notice and return
61
+ engine: str = param('auto', choices=['auto', 'openslide', 'tifffile'])  # forwarded as engine= to Wsi2HDF5Command
62
+ mpp: float = 0  # forwarded as mpp=; presumably microns-per-pixel with 0 meaning "not specified" -- TODO confirm
63
+ norotate: bool = False  # inverted before use: the command receives rotate=not norotate
64
+
65
def run_wsi2h5(self, a:Wsi2h5Args):
    """Convert a slide into an HDF5 patch file via ``commands.Wsi2HDF5Command``.

    The destination defaults to the input path with its extension replaced by
    ``.h5``. An existing destination is left alone unless ``a.overwrite`` is set.
    """
    target = a.output_path
    if not target:
        target = os.path.splitext(a.input_path)[0] + '.h5'

    if os.path.exists(target):
        if not a.overwrite:
            print(f'{target} exists. Skipping.')
            return
        print(f'{target} exists but overwriting it.')

    # Create the destination directory only when the path actually has one.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # The progress backend comes from the command layer's global config.
    converter = commands.Wsi2HDF5Command(
        patch_size=a.patch_size,
        engine=a.engine,
        mpp=a.mpp,
        rotate=not a.norotate,
    )
    result = converter(a.input_path, target)
    print(f"done: {result['patch_count']} patches extracted")
90
+
91
+ class EmbedArgs(CommonArgs):  # options for the embed sub-command (consumed by run_embed)
92
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file passed to PatchEmbeddingCommand
93
+ batch_size: int = Field(512, s='-B')  # forwarded as batch_size=
94
+ overwrite: bool = Field(False, s='-O')  # forwarded as overwrite=
95
+ model_name: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni'], l='--model', s='-M')  # NOTE(review): keyword is `choice=` here but `choices=` in Wsi2h5Args/ClusterArgs -- likely a typo, so this list is presumably not enforced; confirm
96
+ with_latent_features: bool = Field(False, s='-L')  # forwarded as with_latent=
97
+
98
def run_embed(self, a:EmbedArgs):
    """Extract patch embeddings for the HDF5 file at ``a.input_path``.

    Publishes the requested device and model as the command layer's global
    defaults, then runs ``commands.PatchEmbeddingCommand``. A summary line is
    printed unless the command reports that it skipped the file.
    """
    commands.set_default_device(a.device)
    # --model used to be parsed and then dropped, so embedding always ran with
    # whatever model was configured at import time. Publish it the same way
    # run_cluster does so the command falls back to the requested model.
    commands.set_default_model(a.model_name)

    embedder = commands.PatchEmbeddingCommand(
        batch_size=a.batch_size,
        with_latent=a.with_latent_features,
        overwrite=a.overwrite,
    )
    result = embedder(a.input_path)

    if not result.skipped:
        print(f"done: {result.feature_dim}D features extracted")
111
+
112
+
113
+ class ProcessSlideArgs(CommonArgs):  # options for the process-slide sub-command (consumed by run_process_slide)
114
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file that is read and then updated in place
115
+ overwrite: bool = Field(False, s='-O')  # allow recomputing when 'gigapath/slide_feature' is already present
116
+
117
def run_process_slide(self, a:ProcessSlideArgs):
    """Compute a slide-level feature with the GigaPath slide encoder.

    Reads ``gigapath/features`` and ``coordinates`` from the HDF5 file, runs
    the LongNet slide encoder on them and stores the result under
    ``gigapath/slide_feature``. A legacy top-level ``slide_feature`` dataset is
    moved to the new location instead of being recomputed.
    """
    from gigapath import slide_encoder

    target_key = 'gigapath/slide_feature'
    legacy_key = 'slide_feature'

    with h5py.File(a.input_path, 'a') as f:
        if target_key in f and not a.overwrite:
            print('feature embeddings are already obtained.')
            return
        # Migrate only when the new key is free; create_dataset would raise
        # if both datasets were present.
        if legacy_key in f and target_key not in f:
            f.create_dataset(target_key, data=f[legacy_key][:])
            del f[legacy_key]
            print('Migrated from "slide_feature" to "gigapath/slide_feature"')
            return
        features = f['gigapath/features'][:]
        coords = f['coordinates'][:]

    features = torch.tensor(features, dtype=torch.float32)[None, ...].to(a.device)  # (1, L, D)
    coords = torch.tensor(coords, dtype=torch.float32)[None, ...].to(a.device)  # (1, L, 2)

    print('Loading LongNet...')
    long_net = slide_encoder.create_model(
        'data/slide_encoder.pth',
        'gigapath_slide_enc12l768d',
        1536,
    ).eval().to(a.device)
    print('LongNet loaded.')

    # autocast takes a bare device type ('cuda'), not an indexed device
    # string such as 'cuda:0'.
    device_type = torch.device(a.device).type
    with torch.no_grad(), autocast(device_type, dtype=torch.float16):
        output = long_net(features, coords)
    slide_feature = output[0][0].cpu().detach()

    print('slide_feature dimension:', slide_feature.shape)

    with h5py.File(a.input_path, 'a') as f:
        # The old code tested the legacy 'slide_feature' key here, so a real
        # overwrite never deleted the existing dataset and create_dataset
        # failed. Reaching this point with the key present implies
        # a.overwrite is set.
        if target_key in f:
            print('Overwriting slide_feature.')
            del f[target_key]
        f.create_dataset(target_key, data=slide_feature)
160
+
161
+
162
+ class ClusterArgs(CommonArgs):  # options for the cluster sub-command (consumed by run_cluster)
163
+ input_paths: list[str] = Field(..., l='--in', s='-i')  # required; one or more HDF5 files passed together to ClusteringCommand
164
+ cluster_name: str = Field('', l='--name', s='-n')  # forwarded as cluster_name=
165
+ sub: list[int] = Field([], l='--sub', s='-s')  # forwarded as cluster_filter=; also encoded into the figure file name as 'sub-<ids>_'
166
+ model: str = Field(DEFAULT_MODEL, choices=['gigapath', 'uni', 'virchow2'])  # installed as the global default model before the command is built
167
+ resolution: float = 1  # forwarded as resolution=
168
+ use_umap_embs: float = False  # forwarded as use_umap=; NOTE(review): annotated float with a bool default -- presumably meant to be bool, confirm
169
+ save: bool = False  # write the UMAP figure to '<base>_<sub>umap.png'
170
+ noshow: bool = False  # skip plt.show()
171
+ overwrite: bool = Field(False, s='-O')  # forwarded as overwrite=
172
+
173
def run_cluster(self, a:ClusterArgs):
    """Cluster patch features of one or more HDF5 files and plot the UMAP.

    The figure path is derived from the single input file, or from the
    directory of the first input plus ``a.cluster_name`` when several files
    are given. The figure is written only with ``--save`` and displayed
    unless ``--noshow`` is set.
    """
    commands.set_default_model(a.model)

    cmd = commands.ClusteringCommand(
        resolution=a.resolution,
        cluster_name=a.cluster_name,
        cluster_filter=a.sub,
        use_umap=a.use_umap_embs,
        overwrite=a.overwrite,
    )
    cmd(a.input_paths)

    if len(a.input_paths) > 1:
        # ClusterArgs has no `name` field, so the old `a.name` raised
        # AttributeError here; the intended field is `cluster_name`.
        # os.path.join also avoids building '/<name>' when the first input
        # has no directory component.
        parent = os.path.dirname(a.input_paths[0])
        base = os.path.join(parent, a.cluster_name)
    else:
        base, _ = os.path.splitext(a.input_paths[0])

    sub_tag = ''
    if a.sub:
        sub_tag = 'sub-' + '-'.join(str(i) for i in a.sub) + '_'
    fig_path = f'{base}_{sub_tag}umap.png'

    umap_embs = cmd.get_umap_embeddings()
    fig = plot_umap(umap_embs, cmd.total_clusters)
    if a.save:
        fig.savefig(fig_path)
        print(f'wrote {fig_path}')

    if not a.noshow:
        plt.show()
207
+
208
+
209
+ class ClusterScoresArgs(CommonArgs):  # options for the cluster-scores sub-command (consumed by run_cluster_scores)
210
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file that is read and then receives the score dataset
211
+ name: str = Field(...)  # required; scores are stored under '<model>/scores_<name>'
212
+ clusters: list[int] = Field([], s='-C')  # cluster ids whose patches are selected (via np.isin) for scoring
213
+ model: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni', 'none'])  # HDF5 group prefix; NOTE(review): `choice=` rather than `choices=` -- likely a typo, confirm
214
+ scaler: str = Field('minmax', choices=['std', 'minmax'])  # 'minmax' -> MinMaxScaler; 'std' -> StandardScaler followed by sigmoid
215
+ save: bool = False  # save the violin plot as '<stem>_score-<name>_pca.png' next to the input
216
+ noshow: bool = False  # do not display the plot window
217
+
218
def run_cluster_scores(self, a:ClusterScoresArgs):
    """Score patches of selected clusters by their first principal component.

    Features of the patches whose cluster id is in ``a.clusters`` are reduced
    to one PCA component and rescaled (min-max, or standardise + sigmoid).
    The scores are written to ``<model>/scores_<name>`` as a per-patch array
    with NaN for unselected patches. A violin plot per cluster is optionally
    saved and/or shown.

    Raises:
        ValueError: if no patch matches ``a.clusters`` or the scaler name is
            unknown.
    """
    with h5py.File(a.input_path, 'r') as f:
        patch_count = f['metadata/patch_count'][()]
        clusters = f[f'{a.model}/clusters'][:]
        mask = np.isin(clusters, a.clusters)
        if not mask.any():
            # PCA on an empty selection fails with an obscure sklearn error.
            raise ValueError(
                f'No patches belong to clusters {a.clusters} in {a.input_path}')
        masked_clusters = clusters[mask]
        masked_features = f[f'{a.model}/features'][mask]

    values = PCA(n_components=1).fit_transform(masked_features)

    if a.scaler == 'minmax':
        values = MinMaxScaler().fit_transform(values)
    elif a.scaler == 'std':
        values = sigmoid(StandardScaler().fit_transform(values))
    else:
        raise ValueError(f'Invalid scaler: {a.scaler}')

    data = [values[masked_clusters == target].flatten() for target in a.clusters]
    labels = [f'Cluster {target}' for target in a.clusters]

    with h5py.File(a.input_path, 'a') as f:
        path = f'{a.model}/scores_{a.name}'
        if path in f:
            del f[path]
            print(f'Deleted {path}')
        # Per-patch array; patches outside the selection stay NaN.
        vv = np.full(patch_count, np.nan, dtype=values.dtype)
        vv[mask] = values[:, 0]
        f[path] = vv
        print(f'Wrote {path} in {a.input_path}')

    # Build the figure when it is going to be saved OR shown. Previously
    # --save was silently ignored whenever --noshow was given.
    if a.save or not a.noshow:
        plt.figure(figsize=(12, 8))
        sns.set_style('whitegrid')
        ax = plt.subplot(111)
        # cut=0 keeps each violin within the range of the observed data
        sns.violinplot(data=data, ax=ax, inner='box', cut=0, zorder=1, alpha=0.5)

        # Overlay the raw points with a little horizontal jitter.
        for i, d in enumerate(data):
            x = np.random.normal(i, 0.05, size=len(d))
            ax.scatter(x, d, alpha=.8, s=5, color=f'C{i}', zorder=2)

        ax.set_xticks(np.arange(0, len(labels)))
        ax.set_xticklabels(labels)
        ax.set_ylabel('PCA Values')
        ax.set_title('Distribution of PCA Values by Cluster')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        if a.save:
            p = P(a.input_path)
            fig_path = str(p.parent / f'{p.stem}_score-{a.name}_pca.png')
            plt.savefig(fig_path)
            print(f'wrote {fig_path}')
        if not a.noshow:
            plt.show()
279
+
280
+
281
+ class ClusterLatentArgs(CommonArgs):  # options for the cluster-latent sub-command (consumed by run_cluster_latent)
282
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file holding '<model>/latent_features'
283
+ name: str = ''  # NOTE(review): not read by run_cluster_latent
284
+ model: str = Field(DEFAULT_MODEL, choices=['gigapath', 'uni', 'virchow2'])  # HDF5 group prefix for latent_features / latent_clusters
285
+ resolution: float = 1  # forwarded as resolution= to leiden_cluster
286
+ use_umap_embs: float = False  # NOTE(review): not read by run_cluster_latent; annotated float with a bool default -- presumably meant to be bool
287
+ save: bool = False  # save the UMAP figure as '<stem>_latent_umap.png' next to the input
288
+ noshow: bool = False  # flag for suppressing the plot window
289
+ overwrite: bool = Field(False, s='-O')  # recompute even when '<model>/latent_clusters' already exists
290
+
291
def run_cluster_latent(self, a):
    """Leiden-cluster the latent features of a slide and plot their UMAP.

    ``<model>/latent_features`` is flattened over its first two axes,
    clustered with ``leiden_cluster`` and stored as
    ``<model>/latent_clusters``. Existing clusters are reused unless
    ``a.overwrite`` is set. A 2-D UMAP of the flattened features, coloured by
    cluster, is then saved and/or shown.
    """
    # `umap` (umap-learn) was referenced below without ever being imported in
    # this module, which raised NameError. Import it locally so the heavy
    # dependency is only loaded for this sub-command.
    import umap

    target_path = f'{a.model}/latent_clusters'
    clusters = None
    with h5py.File(a.input_path, 'r') as f:
        features = f[f'{a.model}/latent_features'][:]
        if target_path in f:
            if a.overwrite:
                print(f'overwriting old {target_path} of {a.input_path}')
            else:
                # Reuse what is already stored.
                clusters = f[target_path][:]

    s = features.shape
    n_rows = s[0] * s[1]
    # Original author's note on this reshape: "B*16*16, 3" -- i.e. presumably
    # one row per latent token; TODO confirm the feature layout.
    flat_features = features.reshape(n_rows, s[-1])

    if clusters is None:
        clusters = leiden_cluster(flat_features,
                                  umap_emb_func=None,
                                  resolution=a.resolution,
                                  n_jobs=-1,
                                  progress='tqdm')
        clusters = clusters.reshape(s[0], s[1])

        with h5py.File(a.input_path, 'a') as f:
            if target_path in f:
                del f[target_path]
            f.create_dataset(target_path, data=clusters)

    umap_input = features.reshape(n_rows, -1)
    flat_clusters = clusters.reshape(n_rows)
    print(umap_input.shape)
    print(flat_clusters.shape)

    reducer = umap.UMAP(n_components=2)
    embs = reducer.fit_transform(umap_input)
    plot_umap(embeddings=embs, clusters=flat_clusters)
    if a.save:
        p = P(a.input_path)
        fig_path = str(p.parent / f'{p.stem}_latent_umap.png')
        plt.savefig(fig_path)
        print(f'wrote {fig_path}')
    # --noshow used to be ignored here (plt.show() ran unconditionally).
    if not a.noshow:
        plt.show()
344
+
345
+
346
+ class PreviewArgs(CommonArgs):  # options for the preview sub-command (consumed by run_preview)
347
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file passed to PreviewClustersCommand
348
+ output_path: str = Field('', l='--out', s='-o')  # empty -> '<input base>[_<cluster_name>]_thumb.jpg'
349
+ model: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni', 'virchow2'])  # forwarded as model_name=; NOTE(review): `choice=` rather than `choices=` -- likely a typo, confirm
350
+ cluster_name: str = Field('', l='--name', s='-N')  # forwarded as cluster_name= and used in the default output name
351
+ size: int = 64  # forwarded as size=
352
+ open: bool = False  # open the written image with xdg-open afterwards
353
+
354
def run_preview(self, a):
    """Render a cluster-overlay thumbnail of a slide and save it.

    The output path defaults to ``<input base>[_<cluster_name>]_thumb.jpg``.
    With ``a.open`` set, the image is handed to ``xdg-open`` afterwards
    (best effort).
    """
    import subprocess

    output_path = a.output_path
    if not output_path:
        base, _ = os.path.splitext(a.input_path)
        if a.cluster_name:
            output_path = f'{base}_{a.cluster_name}_thumb.jpg'
        else:
            output_path = f'{base}_thumb.jpg'

    cmd = commands.PreviewClustersCommand(
        size=a.size,
        model_name=a.model,
    )
    img = cmd(a.input_path, cluster_name=a.cluster_name)
    img.save(output_path)
    print(f'wrote {output_path}')

    if a.open:
        # Pass the path as an argv element instead of interpolating it into a
        # shell string: paths with spaces or shell metacharacters then work
        # and cannot be interpreted as commands.
        try:
            subprocess.run(['xdg-open', output_path], check=False)
        except OSError as e:
            # Opening the viewer is a convenience; do not fail the command.
            print(f'could not open {output_path}: {e}')
373
+
374
+
375
+ class PreviewScoresArgs(CommonArgs):  # options for the preview-scores sub-command (consumed by run_preview_scores)
376
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file passed to PreviewScoresCommand
377
+ output_path: str = Field('', l='--out', s='-o')  # empty -> '<input base>_score-<score_name>_thumb.jpg'
378
+ model: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni', 'unified', 'none'])  # forwarded as model_name=; NOTE(review): `choice=` rather than `choices=` -- likely a typo, confirm
379
+ score_name: str = Field(..., l='--name', s='-N')  # required; forwarded as score_name= and used in the default output name
380
+ size: int = 64  # forwarded as size=
381
+ open: bool = False  # open the written image with xdg-open afterwards
382
+
383
def run_preview_scores(self, a):
    """Render a score-overlay thumbnail of a slide and save it.

    The output path defaults to ``<input base>_score-<score_name>_thumb.jpg``.
    With ``a.open`` set, the image is handed to ``xdg-open`` afterwards
    (best effort).
    """
    import subprocess

    output_path = a.output_path
    if not output_path:
        base, _ = os.path.splitext(a.input_path)
        output_path = f'{base}_score-{a.score_name}_thumb.jpg'

    cmd = commands.PreviewScoresCommand(
        size=a.size,
        model_name=a.model,
    )
    img = cmd(a.input_path, score_name=a.score_name)
    img.save(output_path)
    print(f'wrote {output_path}')

    if a.open:
        # argv list instead of a shell string: safe for paths containing
        # spaces or shell metacharacters.
        try:
            subprocess.run(['xdg-open', output_path], check=False)
        except OSError as e:
            # Opening the viewer is a convenience; do not fail the command.
            print(f'could not open {output_path}: {e}')
399
+
400
+
401
+ class PreviewLatentPcaArgs(CommonArgs):  # options for the preview-latent-pca sub-command (consumed by run_preview_latent_pca)
402
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file passed to PreviewLatentPCACommand
403
+ output_path: str = Field('', l='--out', s='-o')  # empty -> '<input base>_latent_pca.jpg'
404
+ model: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni', 'none'])  # installed as the global default model; NOTE(review): `choice=` rather than `choices=` -- likely a typo, confirm
405
+ size: int = 64  # forwarded as size=
406
+ open: bool = False  # open the written image with xdg-open afterwards
407
+
408
def run_preview_latent_pca(self, a:PreviewLatentPcaArgs):
    """Render the latent-feature PCA preview of a slide and save it.

    The output path defaults to ``<input base>_latent_pca.jpg``. With
    ``a.open`` set, the image is handed to ``xdg-open`` afterwards
    (best effort).
    """
    import subprocess

    output_path = a.output_path
    if not output_path:
        base, _ = os.path.splitext(a.input_path)
        output_path = f'{base}_latent_pca.jpg'

    # The command is built with size= only, so the model is published through
    # the command layer's global default.
    commands.set_default_model(a.model)
    cmd = commands.PreviewLatentPCACommand(size=a.size)
    img = cmd(a.input_path)
    img.save(output_path)
    print(f'wrote {output_path}')

    if a.open:
        # argv list instead of a shell string: safe for paths containing
        # spaces or shell metacharacters.
        try:
            subprocess.run(['xdg-open', output_path], check=False)
        except OSError as e:
            # Opening the viewer is a convenience; do not fail the command.
            print(f'could not open {output_path}: {e}')
423
+
424
+
425
+ class PreviewLatentArgs(CommonArgs):  # options for the preview-latent sub-command (consumed by run_preview_latent)
426
+ input_path: str = Field(..., l='--in', s='-i')  # required; HDF5 file passed to PreviewLatentClusterCommand
427
+ output_path: str = Field('', l='--out', s='-o')  # empty -> '<input base>_latent_clusters.jpg'
428
+ model: str = Field(DEFAULT_MODEL, choice=['gigapath', 'uni', 'none'])  # installed as the global default model; NOTE(review): `choice=` rather than `choices=` -- likely a typo, confirm
429
+ size: int = 64  # forwarded as size=
430
+ open: bool = False  # open the written image with xdg-open afterwards
431
+
432
def run_preview_latent(self, a:PreviewLatentArgs):
    """Render the latent-cluster preview of a slide and save it.

    The output path defaults to ``<input base>_latent_clusters.jpg``. With
    ``a.open`` set, the image is handed to ``xdg-open`` afterwards
    (best effort).
    """
    import subprocess

    output_path = a.output_path
    if not output_path:
        base, _ = os.path.splitext(a.input_path)
        output_path = f'{base}_latent_clusters.jpg'

    # The command is built with size= only, so the model is published through
    # the command layer's global default.
    commands.set_default_model(a.model)
    cmd = commands.PreviewLatentClusterCommand(size=a.size)
    img = cmd(a.input_path)
    img.save(output_path)
    print(f'wrote {output_path}')

    if a.open:
        # argv list instead of a shell string: safe for paths containing
        # spaces or shell metacharacters.
        try:
            subprocess.run(['xdg-open', output_path], check=False)
        except OSError as e:
            # Opening the viewer is a convenience; do not fail the command.
            print(f'could not open {output_path}: {e}')
447
+
448
+
449
+ class ExportDziArgs(CommonArgs):  # options for the export-dzi sub-command (consumed by run_export_dzi)
450
+ input_h5: str = Field(..., l='--input', s='-i', description='入力HDF5ファイルパス')  # required; description reads "input HDF5 file path"
451
+ output_dir: str = Field(..., l='--output', s='-o', description='出力ディレクトリ')  # required; description reads "output directory"
452
+ jpeg_quality: int = Field(90, s='-q', description='JPEG品質(1-100)')  # description reads "JPEG quality (1-100)"; forwarded as jpeg_quality=
453
+ fill_empty: bool = Field(False, l='--fill-empty', description='空白パッチに黒画像を出力')  # description reads "output black images for empty patches"; forwarded as fill_empty=
454
+
455
def run_export_dzi(self, a: ExportDziArgs):
    """Write the patches of an HDF5 file out in Deep Zoom Image (DZI) form for OpenSeadragon."""
    # The DZI is named after the H5 file (file name without extension).
    dzi_name = P(a.input_h5).stem
    # The destination is the directory the user asked for, taken as-is
    # (round-tripped through Path exactly as before).
    destination = str(P(a.output_dir))

    exporter = commands.DziExportCommand(
        jpeg_quality=a.jpeg_quality,
        fill_empty=a.fill_empty,
    )
    result = exporter(hdf5_path=a.input_h5, output_dir=destination, name=dzi_name)

    print(f'Export completed: {result["dzi_path"]}')
476
+
477
+
478
def main():
    """Console entry point for the wsi-toolbox command: build the CLI and run it."""
    CLI().run()


if __name__ == '__main__':
    main()
@@ -0,0 +1,92 @@
1
+ """
2
+ Command-based processors for WSI analysis pipeline.
3
+
4
+ Design pattern: __init__ for configuration, __call__ for execution
5
+ """
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from ..utils.progress import tqdm_or_st
10
+
11
+
12
+ # === Global Configuration (Pydantic) ===
13
class Config(BaseModel):
    """Global configuration for commands"""
    # Defaults are given positionally; Field's first argument is the default value.
    progress: str = Field('tqdm', description="Progress bar backend")
    model_name: str = Field('uni', description="Default model name")
    verbose: bool = Field(True, description="Verbose output")
    device: str = Field('cuda', description="Device for computation")


# The one shared configuration object for this package.
_config = Config()
23
+
24
+
25
def set_default_progress(backend: str):
    """Store ``backend`` (e.g. 'tqdm', 'streamlit') as the global progress backend."""
    setattr(_config, 'progress', backend)
28
+
29
+
30
def set_default_model(model_name: str):
    """Store ``model_name`` (e.g. 'uni', 'gigapath', 'virchow2') as the global default model."""
    setattr(_config, 'model_name', model_name)
33
+
34
+
35
def set_default_device(device: str):
    """Store ``device`` (e.g. 'cuda', 'cpu') as the global default device."""
    setattr(_config, 'device', device)
38
+
39
+
40
def set_verbose(verbose: bool):
    """Store ``verbose`` as the global verbosity flag."""
    setattr(_config, 'verbose', verbose)
43
+
44
+
45
+ def _get(key: str, value):
46
+ """Get value or fall back to global default"""
47
+ if value is not None:
48
+ return value
49
+ return getattr(_config, key)
50
+
51
+
52
def _progress(iterable, **kwargs):
    """Wrap ``iterable`` with ``tqdm_or_st`` using the globally configured backend."""
    backend = _config.progress
    return tqdm_or_st(iterable, backend=backend, **kwargs)
55
+
56
+
57
+ # Import and export all commands
58
+ from .wsi import Wsi2HDF5Command
59
+ from .patch_embedding import PatchEmbeddingCommand
60
+ from .clustering import ClusteringCommand
61
+ from .preview import (
62
+ BasePreviewCommand,
63
+ PreviewClustersCommand,
64
+ PreviewScoresCommand,
65
+ PreviewLatentPCACommand,
66
+ PreviewLatentClusterCommand,
67
+ )
68
+ from .dzi_export import DziExportCommand
69
+
70
+ __all__ = [  # names exported by a star-import of this package; note it deliberately lists underscore-prefixed helpers too
71
+ # Config
72
+ 'Config',
73
+ '_config',
74
+ # Config setters
75
+ 'set_default_progress',
76
+ 'set_default_model',
77
+ 'set_default_device',
78
+ 'set_verbose',
79
+ # Helper functions
80
+ '_get',
81
+ '_progress',
82
+ # Commands
83
+ 'Wsi2HDF5Command',
84
+ 'PatchEmbeddingCommand',
85
+ 'ClusteringCommand',
86
+ 'BasePreviewCommand',
87
+ 'PreviewClustersCommand',
88
+ 'PreviewScoresCommand',
89
+ 'PreviewLatentPCACommand',
90
+ 'PreviewLatentClusterCommand',
91
+ 'DziExportCommand',
92
+ ]