wsi-toolbox 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +122 -0
- wsi_toolbox/app.py +874 -0
- wsi_toolbox/cli.py +599 -0
- wsi_toolbox/commands/__init__.py +66 -0
- wsi_toolbox/commands/clustering.py +198 -0
- wsi_toolbox/commands/data_loader.py +219 -0
- wsi_toolbox/commands/dzi.py +160 -0
- wsi_toolbox/commands/patch_embedding.py +196 -0
- wsi_toolbox/commands/pca.py +206 -0
- wsi_toolbox/commands/preview.py +394 -0
- wsi_toolbox/commands/show.py +171 -0
- wsi_toolbox/commands/umap_embedding.py +174 -0
- wsi_toolbox/commands/wsi.py +223 -0
- wsi_toolbox/common.py +148 -0
- wsi_toolbox/models.py +30 -0
- wsi_toolbox/utils/__init__.py +109 -0
- wsi_toolbox/utils/analysis.py +174 -0
- wsi_toolbox/utils/hdf5_paths.py +232 -0
- wsi_toolbox/utils/plot.py +227 -0
- wsi_toolbox/utils/progress.py +207 -0
- wsi_toolbox/utils/seed.py +26 -0
- wsi_toolbox/utils/st.py +55 -0
- wsi_toolbox/utils/white.py +121 -0
- wsi_toolbox/watcher.py +256 -0
- wsi_toolbox/wsi_files.py +619 -0
- wsi_toolbox-0.2.0.dist-info/METADATA +253 -0
- wsi_toolbox-0.2.0.dist-info/RECORD +30 -0
- wsi_toolbox-0.2.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.2.0.dist-info/entry_points.txt +3 -0
- wsi_toolbox-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UMAP embedding command for dimensionality reduction
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import h5py
|
|
6
|
+
import numpy as np
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from ..utils.hdf5_paths import build_cluster_path, build_namespace, ensure_groups
|
|
10
|
+
from . import _get, _progress, get_config
|
|
11
|
+
from .data_loader import MultipleContext
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class UmapResult(BaseModel):
    """Outcome of a single UmapCommand invocation."""

    # Number of embedded samples (non-NaN rows when existing coordinates were reused)
    n_samples: int
    # Dimensionality of the stored embedding
    n_components: int
    # Namespace the coordinates were stored under
    namespace: str
    # HDF5 dataset path holding the coordinates
    target_path: str
    # True when existing coordinates were reused instead of recomputed
    skipped: bool = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class UmapCommand:
    """
    Compute UMAP embeddings from features

    Usage:
        # Basic UMAP
        cmd = UmapCommand()
        result = cmd('data.h5')  # → uni/default/umap

        # Multi-file UMAP
        cmd = UmapCommand()
        result = cmd(['001.h5', '002.h5'])  # → uni/001+002/umap

        # UMAP for filtered data
        cmd = UmapCommand(parent_filters=[[1,2,3]])
        result = cmd('data.h5')  # → uni/default/filter/1+2+3/umap

    One instance may be called repeatedly: when no explicit namespace was given,
    the namespace is re-derived from the paths of each call.
    """

    def __init__(
        self,
        namespace: str | None = None,
        parent_filters: list[list[int]] | None = None,
        n_components: int = 2,
        n_neighbors: int = 15,
        min_dist: float = 0.1,
        metric: str = "euclidean",
        overwrite: bool = False,
        model_name: str | None = None,
    ):
        """
        Initialize UMAP command

        Args:
            namespace: Explicit namespace (None = auto-generate from input paths)
            parent_filters: Hierarchical filters, e.g., [[1,2,3]]
            n_components: Number of UMAP dimensions (default: 2)
            n_neighbors: UMAP n_neighbors parameter (default: 15)
            min_dist: UMAP min_dist parameter (default: 0.1)
            metric: UMAP metric (default: "euclidean")
            overwrite: Whether to overwrite existing UMAP coordinates
            model_name: Model name (None to use global default)

        Raises:
            ValueError: If the resolved model name is not 'uni', 'gigapath' or 'virchow2'.
        """
        # Resolved namespace of the most recent call (the requested value until then)
        self.namespace = namespace
        # Namespace exactly as requested; kept apart from ``self.namespace`` so that an
        # auto-generated value from one call never leaks into the next call.
        self._requested_namespace = namespace
        self.parent_filters = parent_filters or []
        self.n_components = n_components
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist
        self.metric = metric
        self.overwrite = overwrite
        self.model_name = _get("model_name", model_name)

        # Validate model
        if self.model_name not in ["uni", "gigapath", "virchow2"]:
            raise ValueError(f"Invalid model: {self.model_name}")

        # Internal state
        self.hdf5_paths = []
        self.umap_embeddings = None

    def __call__(self, hdf5_paths: str | list[str]) -> UmapResult:
        """
        Execute UMAP embedding

        Args:
            hdf5_paths: A single HDF5 path or a list of paths. Features from all
                files are embedded jointly and the coordinates are written back to
                each file under the same target path.

        Returns:
            UmapResult for the stored coordinates; ``skipped=True`` when every file
            already contained them and ``overwrite`` is False.

        Raises:
            ValueError: If no paths are given, or an explicit namespace contains '+'.
        """
        import umap  # noqa: PLC0415 - lazy load, umap is slow to import

        # Normalize to list
        if isinstance(hdf5_paths, str):
            hdf5_paths = [hdf5_paths]
        if not hdf5_paths:
            raise ValueError("At least one HDF5 path is required")
        self.hdf5_paths = hdf5_paths

        namespace = self._resolve_namespace(hdf5_paths)
        self.namespace = namespace

        # Build target path
        target_path = build_cluster_path(self.model_name, namespace, filters=self.parent_filters, dataset="umap")

        # Reuse existing coordinates only when *every* input file already has them
        if not self.overwrite:
            existing = self._find_existing(hdf5_paths, target_path)
            if existing is not None:
                n_samples, n_components = existing
                if get_config().verbose:
                    print(f"UMAP already exists at {target_path}")
                return UmapResult(
                    n_samples=n_samples,
                    n_components=n_components,
                    namespace=namespace,
                    target_path=target_path,
                    skipped=True,
                )

        # Execute with progress tracking
        with _progress(total=3, desc="UMAP") as pbar:
            # Load features
            pbar.set_description("Loading features")
            ctx = MultipleContext(hdf5_paths, self.model_name, namespace, self.parent_filters)
            features = ctx.load_features(source="features")
            pbar.update(1)

            # Compute UMAP
            pbar.set_description("Computing UMAP")
            reducer = umap.UMAP(
                n_components=self.n_components,
                n_neighbors=self.n_neighbors,
                min_dist=self.min_dist,
                metric=self.metric,
            )
            self.umap_embeddings = reducer.fit_transform(features)
            pbar.update(1)

            # Write results
            pbar.set_description("Writing results")
            self._write_results(ctx, target_path)
            pbar.update(1)

        # Verbose output after progress bar closes
        if get_config().verbose:
            print(f"Computing UMAP: {len(features)} samples → {self.n_components}D")
            print(f"Wrote {target_path} to {len(hdf5_paths)} file(s)")

        return UmapResult(
            n_samples=len(features), n_components=self.n_components, namespace=namespace, target_path=target_path
        )

    def _resolve_namespace(self, hdf5_paths: list[str]) -> str:
        """Return the explicitly requested namespace, or one auto-generated from hdf5_paths."""
        requested = self._requested_namespace
        if requested is None:
            return build_namespace(hdf5_paths)
        if "+" in requested:
            raise ValueError("Namespace cannot contain '+' (reserved for multi-file auto-generated namespaces)")
        return requested

    def _find_existing(self, hdf5_paths: list[str], target_path: str) -> tuple[int, int] | None:
        """Return (n_samples, n_components) of stored coordinates if every file has them, else None.

        n_samples counts rows whose first coordinate is not NaN, summed over all
        files, as plain Python ints.
        """
        n_samples = 0
        n_components = self.n_components
        for path in hdf5_paths:
            with h5py.File(path, "r") as f:
                if target_path not in f:
                    return None
                coords = f[target_path][:]
            n_samples += int(np.count_nonzero(~np.isnan(coords[:, 0])))
            n_components = int(coords.shape[1])
        return n_samples, n_components

    def _write_results(self, ctx: MultipleContext, target_path: str):
        """Write UMAP coordinates to HDF5 files

        Each file receives one row per entry of its ``file_slice.mask``: rows the
        mask excludes are NaN, the others hold that file's slice of the joint
        embedding. Any existing dataset at target_path is replaced.
        """
        for file_slice in ctx:
            umap_coords = file_slice.slice(self.umap_embeddings)

            with h5py.File(file_slice.hdf5_path, "a") as f:
                ensure_groups(f, target_path)

                if target_path in f:
                    del f[target_path]

                # Fill with NaN for filtered patches
                full_umap = np.full((len(file_slice.mask), self.n_components), np.nan, dtype=umap_coords.dtype)
                full_umap[file_slice.mask] = umap_coords

                ds = f.create_dataset(target_path, data=full_umap)
                ds.attrs["n_components"] = self.n_components
                ds.attrs["n_neighbors"] = self.n_neighbors
                ds.attrs["min_dist"] = self.min_dist
                ds.attrs["metric"] = self.metric
                ds.attrs["model"] = self.model_name

    def get_embeddings(self):
        """Return the embeddings from the most recent non-skipped call, or None if there was none."""
        return self.umap_embeddings
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WSI to HDF5 conversion command
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Callable
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import h5py
|
|
10
|
+
import numpy as np
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
from ..utils.white import create_white_detector
|
|
14
|
+
from ..wsi_files import create_wsi_file
|
|
15
|
+
from . import _progress, get_config
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Wsi2HDF5Result(BaseModel):
    """Summary returned by Wsi2HDF5Command after the HDF5 file has been written."""

    # Effective microns-per-pixel of the stored patches (original_mpp * scale)
    mpp: float
    # Microns-per-pixel reported by the slide reader
    original_mpp: float
    # Integer downscale factor applied to the level-0 image
    scale: int
    # Number of patches kept after background filtering
    patch_count: int
    # Side length of each stored square patch, in pixels
    patch_size: int
    # Number of patch columns in the grid
    cols: int
    # Number of patch rows in the grid
    rows: int
    # Path of the written HDF5 file
    output_path: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Wsi2HDF5Command:
|
|
32
|
+
"""
|
|
33
|
+
Convert WSI image to HDF5 format with patch extraction
|
|
34
|
+
|
|
35
|
+
Usage:
|
|
36
|
+
# Set global config once
|
|
37
|
+
commands.set_default_progress('tqdm')
|
|
38
|
+
|
|
39
|
+
# Create and run command
|
|
40
|
+
cmd = Wsi2HDF5Command(patch_size=256, engine='auto')
|
|
41
|
+
result = cmd(input_path='image.ndpi', output_path='output.h5')
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
patch_size: int = 256,
|
|
47
|
+
engine: str = "auto",
|
|
48
|
+
mpp: float = 0.5,
|
|
49
|
+
rotate: bool = False,
|
|
50
|
+
white_detector: Callable[[np.ndarray], bool] = None,
|
|
51
|
+
):
|
|
52
|
+
"""
|
|
53
|
+
Initialize WSI to HDF5 converter
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
patch_size: Size of patches to extract
|
|
57
|
+
engine: WSI reader engine ('auto', 'openslide', 'tifffile', 'standard')
|
|
58
|
+
mpp: Microns per pixel (for standard images)
|
|
59
|
+
rotate: Whether to rotate patches 180 degrees
|
|
60
|
+
white_detector: Function that takes (H, W, 3) array and returns bool.
|
|
61
|
+
If None, uses legacy is_white_patch with default params.
|
|
62
|
+
|
|
63
|
+
Note:
|
|
64
|
+
progress and verbose are controlled by global config:
|
|
65
|
+
- commands.set_default_progress('tqdm')
|
|
66
|
+
- commands.set_verbose(True/False)
|
|
67
|
+
"""
|
|
68
|
+
self.patch_size = patch_size
|
|
69
|
+
self.engine = engine
|
|
70
|
+
self.mpp = mpp
|
|
71
|
+
self.rotate = rotate
|
|
72
|
+
|
|
73
|
+
# Set white detection function
|
|
74
|
+
if white_detector is None:
|
|
75
|
+
# Default: use ptp method with default threshold
|
|
76
|
+
self._is_white_patch = create_white_detector("ptp")
|
|
77
|
+
else:
|
|
78
|
+
self._is_white_patch = white_detector
|
|
79
|
+
|
|
80
|
+
def __call__(self, input_path: str, output_path: str) -> Wsi2HDF5Result:
|
|
81
|
+
"""
|
|
82
|
+
Execute WSI to HDF5 conversion
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
input_path: Path to input WSI file
|
|
86
|
+
output_path: Path to output HDF5 file
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
Wsi2HDF5Result: Metadata including mpp, scale, patch_count
|
|
90
|
+
"""
|
|
91
|
+
# Create WSI reader
|
|
92
|
+
wsi = create_wsi_file(input_path, engine=self.engine, mpp=self.mpp)
|
|
93
|
+
|
|
94
|
+
# Calculate scale based on mpp
|
|
95
|
+
original_mpp = wsi.get_mpp()
|
|
96
|
+
|
|
97
|
+
if 0.360 < original_mpp < 0.660:
|
|
98
|
+
# mpp ≃ 0.5 mpp
|
|
99
|
+
scale = 1
|
|
100
|
+
elif original_mpp < 0.360:
|
|
101
|
+
scale = 2
|
|
102
|
+
else:
|
|
103
|
+
raise RuntimeError(f"Invalid mpp: {original_mpp:.6f}")
|
|
104
|
+
|
|
105
|
+
mpp = original_mpp * scale
|
|
106
|
+
|
|
107
|
+
# Get image dimensions
|
|
108
|
+
W, H = wsi.get_original_size()
|
|
109
|
+
S = self.patch_size # Scaled patch size
|
|
110
|
+
T = S * scale # Original patch size
|
|
111
|
+
|
|
112
|
+
x_patch_count = W // T
|
|
113
|
+
y_patch_count = H // T
|
|
114
|
+
width = (W // T) * T
|
|
115
|
+
row_count = H // T
|
|
116
|
+
coordinates = []
|
|
117
|
+
|
|
118
|
+
try:
|
|
119
|
+
# Create HDF5 file
|
|
120
|
+
with h5py.File(output_path, "w") as f:
|
|
121
|
+
# Write metadata (both as datasets and attrs for migration)
|
|
122
|
+
f.create_dataset("metadata/original_mpp", data=original_mpp)
|
|
123
|
+
f.create_dataset("metadata/original_width", data=W)
|
|
124
|
+
f.create_dataset("metadata/original_height", data=H)
|
|
125
|
+
f.create_dataset("metadata/image_level", data=0)
|
|
126
|
+
f.create_dataset("metadata/mpp", data=mpp)
|
|
127
|
+
f.create_dataset("metadata/scale", data=scale)
|
|
128
|
+
f.create_dataset("metadata/patch_size", data=S)
|
|
129
|
+
f.create_dataset("metadata/cols", data=x_patch_count)
|
|
130
|
+
f.create_dataset("metadata/rows", data=y_patch_count)
|
|
131
|
+
|
|
132
|
+
# Also save as attrs for future migration
|
|
133
|
+
f.attrs["original_mpp"] = original_mpp
|
|
134
|
+
f.attrs["original_width"] = W
|
|
135
|
+
f.attrs["original_height"] = H
|
|
136
|
+
f.attrs["image_level"] = 0
|
|
137
|
+
f.attrs["mpp"] = mpp
|
|
138
|
+
f.attrs["scale"] = scale
|
|
139
|
+
f.attrs["patch_size"] = S
|
|
140
|
+
f.attrs["cols"] = x_patch_count
|
|
141
|
+
f.attrs["rows"] = y_patch_count
|
|
142
|
+
|
|
143
|
+
# Create patches dataset
|
|
144
|
+
total_patches = f.create_dataset(
|
|
145
|
+
"patches",
|
|
146
|
+
shape=(x_patch_count * y_patch_count, S, S, 3),
|
|
147
|
+
dtype=np.uint8,
|
|
148
|
+
chunks=(1, S, S, 3),
|
|
149
|
+
compression="gzip",
|
|
150
|
+
compression_opts=9,
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Extract patches row by row
|
|
154
|
+
cursor = 0
|
|
155
|
+
tq = _progress(range(row_count))
|
|
156
|
+
for row in tq:
|
|
157
|
+
# Read one row
|
|
158
|
+
image = wsi.read_region((0, row * T, width, T))
|
|
159
|
+
image = cv2.resize(image, (width // scale, S), interpolation=cv2.INTER_LANCZOS4)
|
|
160
|
+
|
|
161
|
+
# Reshape into patches
|
|
162
|
+
patches = image.reshape(1, S, x_patch_count, S, 3) # (y, h, x, w, 3)
|
|
163
|
+
patches = patches.transpose(0, 2, 1, 3, 4) # (y, x, h, w, 3)
|
|
164
|
+
patches = patches[0]
|
|
165
|
+
|
|
166
|
+
# Filter white patches and collect valid ones
|
|
167
|
+
batch = []
|
|
168
|
+
for col, patch in enumerate(patches):
|
|
169
|
+
if self._is_white_patch(patch):
|
|
170
|
+
continue
|
|
171
|
+
|
|
172
|
+
if self.rotate:
|
|
173
|
+
patch = cv2.rotate(patch, cv2.ROTATE_180)
|
|
174
|
+
coordinates.append(((x_patch_count - 1 - col) * S, (y_patch_count - 1 - row) * S))
|
|
175
|
+
else:
|
|
176
|
+
coordinates.append((col * S, row * S))
|
|
177
|
+
|
|
178
|
+
batch.append(patch)
|
|
179
|
+
|
|
180
|
+
# Write batch
|
|
181
|
+
batch = np.array(batch)
|
|
182
|
+
total_patches[cursor : cursor + len(batch), ...] = batch
|
|
183
|
+
cursor += len(batch)
|
|
184
|
+
|
|
185
|
+
tq.set_description(f"Selected {len(batch)}/{len(patches)} patches (row {row}/{y_patch_count})")
|
|
186
|
+
tq.refresh()
|
|
187
|
+
|
|
188
|
+
# Resize to actual patch count and save coordinates
|
|
189
|
+
patch_count = len(coordinates)
|
|
190
|
+
f.create_dataset("coordinates", data=coordinates)
|
|
191
|
+
f["patches"].resize((patch_count, S, S, 3))
|
|
192
|
+
f.create_dataset("metadata/patch_count", data=patch_count)
|
|
193
|
+
f.attrs["patch_count"] = patch_count
|
|
194
|
+
|
|
195
|
+
except BaseException:
|
|
196
|
+
# Clean up incomplete file on error (including Ctrl-C)
|
|
197
|
+
if os.path.exists(output_path):
|
|
198
|
+
os.remove(output_path)
|
|
199
|
+
raise
|
|
200
|
+
|
|
201
|
+
# Verbose output after progress bar closes
|
|
202
|
+
if get_config().verbose:
|
|
203
|
+
print(f"Original mpp: {original_mpp:.6f}")
|
|
204
|
+
print(f"Image mpp: {mpp:.6f}")
|
|
205
|
+
print(f"Target resolutions: {W} x {H}")
|
|
206
|
+
print(f"Obtained resolutions: {x_patch_count * S} x {y_patch_count * S}")
|
|
207
|
+
print(f"Scale: {scale}")
|
|
208
|
+
print(f"Patch size: {T}")
|
|
209
|
+
print(f"Scaled patch size: {S}")
|
|
210
|
+
print(f"Row count: {y_patch_count}")
|
|
211
|
+
print(f"Col count: {x_patch_count}")
|
|
212
|
+
print(f"{patch_count} patches were selected.")
|
|
213
|
+
|
|
214
|
+
return Wsi2HDF5Result(
|
|
215
|
+
mpp=mpp,
|
|
216
|
+
original_mpp=original_mpp,
|
|
217
|
+
scale=scale,
|
|
218
|
+
patch_count=patch_count,
|
|
219
|
+
patch_size=S,
|
|
220
|
+
cols=x_patch_count,
|
|
221
|
+
rows=y_patch_count,
|
|
222
|
+
output_path=output_path,
|
|
223
|
+
)
|
wsi_toolbox/common.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Global configuration and settings for WSI-toolbox
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Callable
|
|
7
|
+
|
|
8
|
+
from matplotlib import pyplot as plt
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
from .models import MODEL_NAMES, create_foundation_model
|
|
12
|
+
from .utils.progress import Progress
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# === Global Configuration (Pydantic) ===
|
|
16
|
+
class Config(BaseModel):
    """Process-wide defaults shared by every command.

    Read through ``get_config()`` and changed through the module's ``set_*`` helpers.
    """

    # Progress-bar backend name handed to Progress (e.g. "tqdm", "streamlit")
    progress: str = Field(default="tqdm", description="Progress bar backend")
    # Model name used when a command is not given one explicitly
    model_name: str = Field(default="uni", description="Default model name")
    # Zero-argument factory returning a fresh model; None until one is registered
    model_generator: Callable | None = Field(default=None, description="Model generator function")
    # Whether commands print summaries after running
    verbose: bool = Field(default=True, description="Verbose output")
    # Device string for computation, e.g. "cuda" or "cpu"
    device: str = Field(default="cuda", description="Device for computation")
    # Matplotlib colormap name used to color clusters
    cluster_cmap: str = Field(default="tab20", description="Cluster colormap name")

    class Config:
        # Permit field types pydantic cannot validate natively
        arbitrary_types_allowed = True
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Single process-wide configuration object; the set_* helpers mutate it in place.
_config = Config()


def get_config() -> Config:
    """Return the shared configuration object (the live instance, not a copy)."""
    return _config
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def set_default_progress(backend: str):
    """Select the progress-bar backend ('tqdm', 'streamlit', etc.) for all commands.

    Args:
        backend: Backend name that ``_progress`` forwards to Progress.
    """
    _config.progress = backend
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def set_default_model(name: str, generator: Callable, label: str | None = None):
|
|
45
|
+
"""Set custom model generator as default
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
name: Model name (used for file paths, etc.)
|
|
49
|
+
generator: Callable that returns a model instance (e.g., lambda: MyModel())
|
|
50
|
+
label: Display label (defaults to name if not provided)
|
|
51
|
+
|
|
52
|
+
Example:
|
|
53
|
+
>>> set_default_model('resnet', lambda: torchvision.models.resnet50())
|
|
54
|
+
>>> set_default_model('custom', create_my_model, label='My Custom Model')
|
|
55
|
+
"""
|
|
56
|
+
_config.model_name = name
|
|
57
|
+
_config.model_generator = generator
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def set_default_model_preset(preset_name: str):
    """Register one of the bundled foundation models as the default

    Args:
        preset_name: One of 'uni', 'gigapath', 'virchow2'

    Raises:
        ValueError: If preset_name is not listed in MODEL_NAMES.
    """
    if preset_name in MODEL_NAMES:
        _config.model_name = preset_name
        # Defer construction: the factory is only invoked by create_default_model()
        _config.model_generator = partial(create_foundation_model, preset_name)
        return
    raise ValueError(f"Invalid preset: {preset_name}. Must be one of {MODEL_NAMES}")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def create_default_model():
    """Build a fresh model by calling the registered generator.

    Returns:
        Whatever the registered generator returns (a new model instance each call).

    Raises:
        RuntimeError: If no generator has been registered yet.

    Example:
        >>> set_default_model_preset('uni')
        >>> model = create_default_model()  # Creates new UNI model instance
    """
    generator = _config.model_generator
    if generator is not None:
        return generator()
    raise RuntimeError(
        "No model generator registered. Call set_default_model() or set_default_model_preset() first."
    )
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def set_default_device(device: str):
    """Store the default compute device ('cuda', 'cpu') in the global config.

    Args:
        device: Device string, e.g. "cuda" or "cpu".
    """
    _config.device = device
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def set_verbose(verbose: bool):
    """Turn the commands' post-run summaries on or off globally.

    Args:
        verbose: True to print summaries, False to stay quiet.
    """
    _config.verbose = verbose
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def set_default_cluster_cmap(cmap_name: str):
    """Choose the matplotlib colormap ('tab20', 'tab10', 'Set1', etc.) used for clusters.

    Args:
        cmap_name: Any name accepted by ``plt.get_cmap``.
    """
    _config.cluster_cmap = cmap_name
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _get_cluster_color(cluster_id: int):
    """
    Get color for cluster ID using global colormap

    Ids wrap around modulo the colormap's number of entries, so qualitative maps
    with few colors (e.g. tab10) cycle instead of collapsing to one color.

    Args:
        cluster_id: Cluster ID (negative ids also wrap, e.g. -1 -> last entry)

    Returns:
        RGBA tuple from the configured matplotlib colormap
    """
    cmap = plt.get_cmap(_config.cluster_cmap)
    # Integer lookups >= cmap.N return the "over" color, so wrap on cmap.N rather
    # than a hard-coded 20 (identical for tab20, correct for smaller maps).
    return cmap(cluster_id % cmap.N)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _get(key: str, value):
    """Return ``value``, or the global default ``_config.<key>`` when value is None.

    Only ``None`` triggers the fallback; falsy values such as 0, "" or False are
    returned unchanged.
    """
    return getattr(_config, key) if value is None else value
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _progress(iterable=None, total=None, desc="", **kwargs):
    """Build a Progress bar using the globally configured backend.

    Any extra keyword arguments are forwarded to Progress unchanged.
    """
    backend = _config.progress
    return Progress(iterable=iterable, backend=backend, total=total, desc=desc, **kwargs)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Public surface of this module. The underscore helpers are listed explicitly so
# that a star-import also exposes them.
__all__ = [
    # configuration object
    "Config",
    "get_config",
    # global default setters / factories
    "set_default_progress",
    "set_default_model",
    "set_default_model_preset",
    "create_default_model",
    "set_default_device",
    "set_verbose",
    "set_default_cluster_cmap",
    # internal helpers shared with the commands package
    "_get_cluster_color",
    "_get",
    "_progress",
]
|
wsi_toolbox/models.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
MODEL_NAMES = ["uni", "gigapath", "virchow2"]


def create_foundation_model(model_name: str):
    """
    Create a foundation model instance by preset name.

    Args:
        model_name: One of 'uni', 'gigapath', 'virchow2'

    Returns:
        torch.nn.Module: Model instance (not moved to device, not in eval mode)

    Raises:
        ValueError: If model_name is not a known preset. This is raised before
            timm/torch are imported, so a typo fails fast even without them installed.
    """
    if model_name not in MODEL_NAMES:
        raise ValueError(f"Invalid model_name: {model_name}. Must be one of {MODEL_NAMES}")

    # Lazy import: timm/torch are slow to load (~2s), defer until model creation
    import timm  # noqa: PLC0415
    import torch  # noqa: PLC0415
    from timm.layers import SwiGLUPacked  # noqa: PLC0415

    if model_name == "uni":
        return timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, dynamic_img_size=True, init_values=1e-5)

    if model_name == "gigapath":
        # "hf-hub:" is timm's canonical scheme (the legacy "hf_hub" spelling is rewritten to it)
        return timm.create_model("hf-hub:prov-gigapath/prov-gigapath", pretrained=True, dynamic_img_size=True)

    if model_name == "virchow2":
        return timm.create_model(
            "hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU
        )

    # Reached only if MODEL_NAMES gains an entry without a matching branch above
    raise ValueError(f"Invalid model_name: {model_name}. Must be one of {MODEL_NAMES}")
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from matplotlib import colors as mcolors
|
|
5
|
+
from matplotlib import pyplot as plt
|
|
6
|
+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
|
|
7
|
+
from PIL import Image, ImageDraw
|
|
8
|
+
from PIL.Image import Image as ImageType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def yes_no_prompt(question):
    """Ask a yes/no question on stdin; an empty answer counts as yes.

    Surrounding whitespace is ignored, so " y" or "yes " are accepted.

    Args:
        question: Text shown before the " [Y/n]: " suffix.

    Returns:
        True for an empty answer or one starting with "y"/"Y", False otherwise.
    """
    print(f"{question} [Y/n]: ", end="")
    response = input().strip().lower()
    return response == "" or response.startswith("y")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_platform_font():
    """Return a path to a TrueType font for the current platform.

    On Linux the DejaVu Sans location differs between distributions, so the known
    candidates are probed and the first existing one is returned. If none exists,
    the original default path is returned unchanged.

    Returns:
        Font file path (not guaranteed to exist on Windows/macOS).
    """
    if sys.platform == "win32":
        return "C:\\Windows\\Fonts\\msgothic.ttc"  # MS Gothic
    if sys.platform == "darwin":
        return "/System/Library/Fonts/Supplemental/Arial.ttf"

    import os  # noqa: PLC0415 - only needed for the Linux probe

    candidates = (
        "/usr/share/fonts/TTF/DejaVuSans.ttf",  # e.g. Arch Linux layout (previous default)
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # e.g. Debian/Ubuntu layout
    )
    for path in candidates:
        if os.path.exists(path):
            return path
    return candidates[0]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_frame(size, color, text, font):
    """Draw a transparent square overlay with a colored border and a text label.

    Args:
        size: Side length of the square image, in pixels.
        color: Color accepted both by PIL (outline/fill) and by matplotlib's
            ``hex2color``, e.g. a hex string such as "#1f77b4".
        text: Label drawn in the top-left corner.
        font: PIL font used for the label.

    Returns:
        A new RGBA PIL image of size (size, size).
    """
    canvas = Image.new("RGBA", (size, size), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)
    pen.rectangle((0, 0, size, size), outline=color, width=4)

    # Contrast the label against the frame color using its HSV value (brightness)
    brightness = mcolors.rgb_to_hsv(mcolors.hex2color(color))[2]
    label_color = "white" if brightness < 0.9 else "black"

    _, _, right, bottom = np.array(pen.textbbox((0, 0), text, font=font))
    pen.rectangle((4, 4, right + 4, bottom + 4), fill=color)
    pen.text((1, 1), text, font=font, fill=label_color)
    return canvas
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def safe_del(hdf_file, key_path):
    """
    Delete ``key_path`` from an HDF5 file/group, doing nothing when it is absent

    Args:
        hdf_file: h5py.File object (any container supporting ``in`` and ``del`` works)
        key_path: Dataset path to delete
    """
    if key_path not in hdf_file:
        return
    del hdf_file[key_path]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def hover_images_on_scatters(scatters, imagess, ax=None, offset=(150, 30)):
    """Show an image thumbnail next to the scatter point under the mouse cursor.

    Args:
        scatters: Sequence of scatter collections (return values of ``ax.scatter``).
        imagess: Parallel sequence; ``imagess[k][i]`` is the image for point ``i``
            of ``scatters[k]``. Each entry may be a numpy array, a PIL image, or a
            file path (opened on hover).
        ax: Axes the scatters are drawn on (defaults to the current axes).
        offset: Position of the thumbnail relative to the hovered point, in points.

    Raises:
        RuntimeError: From the hover callback, if an image entry has an unsupported type.
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    def as_image(image_or_path):
        if isinstance(image_or_path, np.ndarray):
            return image_or_path
        if isinstance(image_or_path, ImageType):
            return image_or_path
        if isinstance(image_or_path, str):
            return Image.open(image_or_path)
        raise RuntimeError("Invalid param", image_or_path)

    imagebox = OffsetImage(as_image(imagess[0][0]), zoom=0.5)
    imagebox.image.axes = ax
    annot = AnnotationBbox(
        imagebox,
        xy=(0, 0),
        # With boxcoords="offset points", xybox is an offset in points from xy
        xybox=offset,
        boxcoords="offset points",
        pad=0.1,
        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"),
        zorder=100,
    )
    annot.set_visible(False)
    ax.add_artist(annot)

    def hover(event):
        if event.inaxes != ax:
            return
        for sc, images in zip(scatters, imagess):
            contains, info = sc.contains(event)
            if not contains:
                continue
            i = info["ind"][0]
            annot.xy = sc.get_offsets()[i]
            # Offset only: adding the data-space position here would mix data
            # coordinates into a points offset and push the box off-screen.
            annot.xybox = offset
            imagebox.set_data(as_image(images[i]))
            annot.set_visible(True)
            fig.canvas.draw_idle()
            return

        # Cursor is not over any point: hide a previously shown thumbnail
        if annot.get_visible():
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
|