wsi_toolbox-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +119 -0
- wsi_toolbox/app.py +753 -0
- wsi_toolbox/cli.py +485 -0
- wsi_toolbox/commands/__init__.py +92 -0
- wsi_toolbox/commands/clustering.py +214 -0
- wsi_toolbox/commands/dzi_export.py +202 -0
- wsi_toolbox/commands/patch_embedding.py +199 -0
- wsi_toolbox/commands/preview.py +335 -0
- wsi_toolbox/commands/wsi.py +196 -0
- wsi_toolbox/exp.py +466 -0
- wsi_toolbox/models.py +38 -0
- wsi_toolbox/utils/__init__.py +153 -0
- wsi_toolbox/utils/analysis.py +127 -0
- wsi_toolbox/utils/cli.py +25 -0
- wsi_toolbox/utils/helpers.py +57 -0
- wsi_toolbox/utils/progress.py +206 -0
- wsi_toolbox/utils/seed.py +21 -0
- wsi_toolbox/utils/st.py +53 -0
- wsi_toolbox/watcher.py +261 -0
- wsi_toolbox/wsi_files.py +187 -0
- wsi_toolbox-0.1.0.dist-info/METADATA +269 -0
- wsi_toolbox-0.1.0.dist-info/RECORD +25 -0
- wsi_toolbox-0.1.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.1.0.dist-info/entry_points.txt +2 -0
- wsi_toolbox-0.1.0.dist-info/licenses/LICENSE +21 -0
wsi_toolbox/commands/clustering.py

@@ -0,0 +1,214 @@
+"""
+Clustering command for WSI features
+"""
+
+import h5py
+import numpy as np
+from pydantic import BaseModel
+from sklearn.preprocessing import StandardScaler
+
+from ..utils.analysis import leiden_cluster
+from . import _config, _get
+
+
+class ClusteringResult(BaseModel):
+    """Result of clustering operation"""
+    cluster_count: int
+    feature_count: int
+    target_path: str
+    skipped: bool = False
+
+
+class ClusteringCommand:
+    """
+    Perform clustering on extracted features
+
+    Usage:
+        # Set global config once
+        commands.set_default_model('gigapath')
+
+        # Create and run command
+        cmd = ClusteringCommand(resolution=1.0)
+        result = cmd(hdf5_paths=['data1.h5', 'data2.h5'])
+    """
+
+    def __init__(self,
+                 resolution: float = 1.0,
+                 cluster_name: str = '',
+                 cluster_filter: list[int] | None = None,
+                 use_umap: bool = False,
+                 overwrite: bool = False,
+                 model_name: str | None = None):
+        """
+        Initialize clustering command
+
+        Args:
+            resolution: Leiden clustering resolution
+            cluster_name: Name for multi-file clustering
+            cluster_filter: Filter to specific clusters for sub-clustering
+            use_umap: Whether to use UMAP embeddings for clustering
+            overwrite: Whether to overwrite existing clusters
+            model_name: Model name (None to use global default)
+        """
+        self.resolution = resolution
+        self.cluster_name = cluster_name
+        self.cluster_filter = cluster_filter or []
+        self.use_umap = use_umap
+        self.overwrite = overwrite
+        self.model_name = _get('model_name', model_name)
+
+        # Validate model
+        if self.model_name not in ['uni', 'gigapath', 'virchow2']:
+            raise ValueError(f'Invalid model: {self.model_name}')
+
+        self.sub_clustering = len(self.cluster_filter) > 0
+
+        # Internal state (initialized in __call__)
+        self.hdf5_paths = []
+        self.masks = []
+        self.features = None
+        self.total_clusters = None
+        self.umap_embeddings = None
+
+    def __call__(self, hdf5_paths: str | list[str]) -> ClusteringResult:
+        """
+        Execute clustering
+
+        Args:
+            hdf5_paths: Single HDF5 path or list of paths
+
+        Returns:
+            ClusteringResult: Result metadata (cluster_count, etc.)
+        """
+        # Normalize to list
+        if isinstance(hdf5_paths, str):
+            hdf5_paths = [hdf5_paths]
+
+        self.hdf5_paths = hdf5_paths
+        multi = len(hdf5_paths) > 1
+
+        # Validate multi-file clustering
+        if multi and not self.cluster_name:
+            raise RuntimeError('Multiple files provided but cluster_name was not specified.')
+
+        # Determine cluster path
+        if multi:
+            clusters_path = f'{self.model_name}/clusters_{self.cluster_name}'
+        else:
+            clusters_path = f'{self.model_name}/clusters'
+
+        # Load features
+        self._load_features(clusters_path)
+
+        # Check if already exists
+        if not self.sub_clustering and hasattr(self, 'has_clusters') and self.has_clusters and not self.overwrite:
+            if _config.verbose:
+                print('Skip clustering (already exists)')
+            return ClusteringResult(
+                cluster_count=len(np.unique(self.total_clusters)),
+                feature_count=len(self.features),
+                target_path=clusters_path,
+                skipped=True
+            )
+
+        # Perform clustering
+        self.total_clusters = leiden_cluster(
+            self.features,
+            umap_emb_func=self.get_umap_embeddings if self.use_umap else None,
+            resolution=self.resolution,
+            progress=_config.progress
+        )
+
+        # Write results
+        target_path = clusters_path
+        if self.sub_clustering:
+            suffix = '_sub' + '-'.join(map(str, self.cluster_filter))
+            target_path = target_path + suffix
+
+        if _config.verbose:
+            print(f'Writing to {target_path}')
+
+        cursor = 0
+        for hdf5_path, mask in zip(self.hdf5_paths, self.masks):
+            count = np.sum(mask)
+            clusters = self.total_clusters[cursor:cursor + count]
+            cursor += count
+
+            with h5py.File(hdf5_path, 'a') as f:
+                if target_path in f:
+                    del f[target_path]
+
+                # Fill with -1 for filtered patches
+                full_clusters = np.full(len(mask), -1, dtype=clusters.dtype)
+                full_clusters[mask] = clusters
+                f.create_dataset(target_path, data=full_clusters)
+
+        cluster_count = len(np.unique(self.total_clusters))
+
+        return ClusteringResult(
+            cluster_count=cluster_count,
+            feature_count=len(self.features),
+            target_path=target_path
+        )
+
+    def _load_features(self, clusters_path: str):
+        """Load features from HDF5 files"""
+        featuress = []
+        clusterss = []
+        self.masks = []
+
+        for hdf5_path in self.hdf5_paths:
+            with h5py.File(hdf5_path, 'r') as f:
+                patch_count = f['metadata/patch_count'][()]
+
+                # Check existing clusters
+                if clusters_path in f:
+                    clusters = f[clusters_path][:]
+                else:
+                    clusters = None
+
+                # Create mask
+                if self.cluster_filter:
+                    if clusters is None:
+                        raise RuntimeError('Sub-clustering requires pre-computed clusters')
+                    mask = np.isin(clusters, self.cluster_filter)
+                else:
+                    mask = np.ones(patch_count, dtype=bool)
+
+                self.masks.append(mask)
+
+                # Load features
+                feature_path = f'{self.model_name}/features'
+                features = f[feature_path][mask]
+                featuress.append(features)
+
+                # Store existing clusters
+                if clusters is not None:
+                    clusterss.append(clusters[mask])
+
+        # Concatenate and normalize
+        features = np.concatenate(featuress)
+        scaler = StandardScaler()
+        self.features = scaler.fit_transform(features)
+
+        # Store existing clusters state
+        if len(clusterss) == len(self.hdf5_paths):
+            self.has_clusters = True
+            self.total_clusters = np.concatenate(clusterss)
+        elif len(clusterss) == 0:
+            self.has_clusters = False
+            self.total_clusters = None
+        else:
+            raise RuntimeError(
+                f'Cluster count mismatch: {len(clusterss)} vs {len(self.hdf5_paths)}'
+            )
+
+    def get_umap_embeddings(self):
+        """Get UMAP embeddings (lazy evaluation)"""
+        import umap
+        if self.umap_embeddings is not None:
+            return self.umap_embeddings
+
+        reducer = umap.UMAP(n_components=2)
+        self.umap_embeddings = reducer.fit_transform(self.features)
+        return self.umap_embeddings
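ClusteringCommand writes the labels back into each input file under {model_name}/clusters (or clusters_<cluster_name> for multi-file runs), with -1 filled in for patches excluded by cluster_filter. A minimal sketch of reading those labels back after a run — the file name and the gigapath group below are illustrative assumptions, not part of the package:

    import h5py
    import numpy as np

    # Hypothetical file already processed by the embedding and clustering commands
    with h5py.File('data1.h5', 'r') as f:
        labels = f['gigapath/clusters'][:]

    valid = labels >= 0  # -1 marks patches excluded from clustering
    print('clusters found:', np.unique(labels[valid]))
    print('patches per cluster:', np.bincount(labels[valid]))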
wsi_toolbox/commands/dzi_export.py

@@ -0,0 +1,202 @@
+"""
+DZI export command for Deep Zoom Image format
+"""
+
+import math
+import shutil
+from pathlib import Path
+
+import h5py
+import numpy as np
+from PIL import Image
+
+from . import _config, _progress
+
+
+class DziExportCommand:
+    """
+    Export HDF5 patches to DZI (Deep Zoom Image) format
+
+    Usage:
+        cmd = DziExportCommand(jpeg_quality=90, fill_empty=False)
+        cmd(hdf5_path='data.h5', output_dir='output', name='slide')
+    """
+
+    def __init__(self,
+                 jpeg_quality: int = 90,
+                 fill_empty: bool = False):
+        """
+        Initialize DZI export command
+
+        Args:
+            jpeg_quality: JPEG compression quality (0-100)
+            fill_empty: Fill missing tiles with black tiles
+        """
+        self.jpeg_quality = jpeg_quality
+        self.fill_empty = fill_empty
+
+    def __call__(self, hdf5_path: str, output_dir: str, name: str) -> dict:
+        """
+        Export to DZI format with full pyramid
+
+        Args:
+            hdf5_path: Path to HDF5 file
+            output_dir: Output directory
+            name: Base name for DZI files
+
+        Returns:
+            dict: Export metadata
+        """
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Read HDF5
+        with h5py.File(hdf5_path, 'r') as f:
+            patches = f['patches'][:]
+            coords = f['coordinates'][:]
+            original_width = f['metadata/original_width'][()]
+            original_height = f['metadata/original_height'][()]
+            tile_size = f['metadata/patch_size'][()]
+
+        # Validate tile_size (256 or 512 only)
+        if tile_size not in [256, 512]:
+            raise ValueError(f'Unsupported patch_size: {tile_size}. Only 256 or 512 are supported.')
+
+        # Calculate grid and levels
+        cols = (original_width + tile_size - 1) // tile_size
+        rows = (original_height + tile_size - 1) // tile_size
+        max_dimension = max(original_width, original_height)
+        max_level = math.ceil(math.log2(max_dimension))
+
+        if _config.verbose:
+            print(f'Original size: {original_width}x{original_height}')
+            print(f'Tile size: {tile_size}')
+            print(f'Grid: {cols}x{rows}')
+            print(f'Total patches in HDF5: {len(patches)}')
+            print(f'Max zoom level: {max_level} (Level 0 = 1x1, Level {max_level} = original)')
+
+        coord_to_idx = {(int(x // tile_size), int(y // tile_size)): idx
+                        for idx, (x, y) in enumerate(coords)}
+
+        # Setup directories
+        dzi_path = output_dir / f'{name}.dzi'
+        files_dir = output_dir / f'{name}_files'
+        files_dir.mkdir(exist_ok=True)
+
+        # Create empty tile template for current tile_size
+        empty_tile_path = None
+        if self.fill_empty:
+            empty_tile_path = files_dir / '_empty.jpeg'
+            black_img = Image.fromarray(np.zeros((tile_size, tile_size, 3), dtype=np.uint8))
+            black_img.save(empty_tile_path, 'JPEG', quality=self.jpeg_quality)
+
+        # Export max level (original patches from HDF5)
+        level_dir = files_dir / str(max_level)
+        level_dir.mkdir(exist_ok=True)
+
+        tq = _progress(range(rows))
+        for row in tq:
+            tq.set_description(f'Exporting level {max_level}: row {row+1}/{rows}')
+            for col in range(cols):
+                tile_path = level_dir / f'{col}_{row}.jpeg'
+                if (col, row) in coord_to_idx:
+                    idx = coord_to_idx[(col, row)]
+                    patch = patches[idx]
+                    img = Image.fromarray(patch)
+                    img.save(tile_path, 'JPEG', quality=self.jpeg_quality)
+                elif self.fill_empty:
+                    shutil.copyfile(empty_tile_path, tile_path)
+
+        # Generate lower levels by downsampling
+        for level in range(max_level - 1, -1, -1):
+            if _config.verbose:
+                print(f'Generating level {level}...')
+            self._generate_zoom_level_down(
+                files_dir, level, max_level, original_width, original_height,
+                tile_size, empty_tile_path
+            )
+
+        # Generate DZI XML
+        self._generate_dzi_xml(dzi_path, original_width, original_height, tile_size)
+
+        if _config.verbose:
+            print(f'DZI export complete: {dzi_path}')
+
+        return {
+            'dzi_path': str(dzi_path),
+            'max_level': max_level,
+            'tile_size': tile_size,
+            'grid': f'{cols}x{rows}'
+        }
+
+    def _generate_zoom_level_down(self,
+                                  files_dir: Path,
+                                  curr_level: int,
+                                  max_level: int,
+                                  original_width: int,
+                                  original_height: int,
+                                  tile_size: int,
+                                  empty_tile_path: Path | None):
+        """Generate a zoom level by downsampling from the higher level"""
+        src_level = curr_level + 1
+        src_dir = files_dir / str(src_level)
+        curr_dir = files_dir / str(curr_level)
+        curr_dir.mkdir(exist_ok=True)
+
+        # Calculate dimensions at each level
+        curr_scale = 2 ** (max_level - curr_level)
+        curr_width = math.ceil(original_width / curr_scale)
+        curr_height = math.ceil(original_height / curr_scale)
+        curr_cols = math.ceil(curr_width / tile_size)
+        curr_rows = math.ceil(curr_height / tile_size)
+
+        src_scale = 2 ** (max_level - src_level)
+        src_width = math.ceil(original_width / src_scale)
+        src_height = math.ceil(original_height / src_scale)
+        src_cols = math.ceil(src_width / tile_size)
+        src_rows = math.ceil(src_height / tile_size)
+
+        tq = _progress(range(curr_rows))
+        for row in tq:
+            for col in range(curr_cols):
+                # Combine 4 tiles from source level
+                combined = np.zeros((tile_size * 2, tile_size * 2, 3), dtype=np.uint8)
+                has_any_tile = False
+
+                for dy in range(2):
+                    for dx in range(2):
+                        src_col = col * 2 + dx
+                        src_row = row * 2 + dy
+
+                        if src_col < src_cols and src_row < src_rows:
+                            src_path = src_dir / f'{src_col}_{src_row}.jpeg'
+                            if src_path.exists():
+                                src_img = Image.open(src_path)
+                                src_array = np.array(src_img)
+                                h, w = src_array.shape[:2]
+                                combined[dy*tile_size:dy*tile_size+h,
+                                         dx*tile_size:dx*tile_size+w] = src_array
+                                has_any_tile = True
+
+                tile_path = curr_dir / f'{col}_{row}.jpeg'
+                if has_any_tile:
+                    combined_img = Image.fromarray(combined)
+                    downsampled = combined_img.resize((tile_size, tile_size), Image.LANCZOS)
+                    downsampled.save(tile_path, 'JPEG', quality=self.jpeg_quality)
+                elif self.fill_empty and empty_tile_path:
+                    shutil.copyfile(empty_tile_path, tile_path)
+
+            tq.set_description(f'Generating level {curr_level}: row {row+1}/{curr_rows}')
+
+    def _generate_dzi_xml(self, dzi_path: Path, width: int, height: int, tile_size: int):
+        """Generate DZI XML file"""
+        dzi_content = f'''<?xml version="1.0" encoding="utf-8"?>
+<Image xmlns="http://schemas.microsoft.com/deepzoom/2008"
+       Format="jpeg"
+       Overlap="0"
+       TileSize="{tile_size}">
+    <Size Width="{width}" Height="{height}"/>
+</Image>
+'''
+        with open(dzi_path, 'w', encoding='utf-8') as f:
+            f.write(dzi_content)
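A minimal usage sketch for the exporter — it assumes an HDF5 file that already contains the patches, coordinates, and metadata/* datasets the command reads; the file and directory names are placeholders:

    from wsi_toolbox.commands.dzi_export import DziExportCommand

    cmd = DziExportCommand(jpeg_quality=90, fill_empty=True)
    info = cmd(hdf5_path='slide.h5', output_dir='dzi_out', name='slide')

    # Expected layout: dzi_out/slide.dzi plus dzi_out/slide_files/<level>/<col>_<row>.jpeg,
    # with <level> running from 0 up to info['max_level'].
    print(info)

The resulting slide.dzi / slide_files pair follows the standard Deep Zoom layout, so a DZI-capable viewer such as OpenSeadragon should be able to open it directly.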
wsi_toolbox/commands/patch_embedding.py

@@ -0,0 +1,199 @@
+"""
+Patch embedding extraction command
+"""
+
+import gc
+
+import h5py
+import numpy as np
+import torch
+from pydantic import BaseModel
+
+from ..models import create_model
+from ..utils.helpers import safe_del
+from . import _config, _get, _progress
+
+
+class PatchEmbeddingResult(BaseModel):
+    """Result of patch embedding extraction"""
+    feature_dim: int = 0
+    patch_count: int = 0
+    model: str = ''
+    with_latent: bool = False
+    skipped: bool = False
+
+
+class PatchEmbeddingCommand:
+    """
+    Extract embeddings from patches using foundation models
+
+    Usage:
+        # Set global config once
+        commands.set_default_model('gigapath')
+        commands.set_default_device('cuda')
+
+        # Create and run command
+        cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
+        result = cmd(hdf5_path='data.h5')
+    """
+
+    def __init__(self,
+                 batch_size: int = 256,
+                 with_latent: bool = False,
+                 overwrite: bool = False,
+                 model_name: str | None = None,
+                 device: str | None = None):
+        """
+        Initialize patch embedding extractor
+
+        Args:
+            batch_size: Batch size for inference
+            with_latent: Whether to extract latent features
+            overwrite: Whether to overwrite existing features
+            model_name: Model name (None to use global default)
+            device: Device (None to use global default)
+
+        Note:
+            progress and verbose are controlled by global config
+        """
+        self.batch_size = batch_size
+        self.with_latent = with_latent
+        self.overwrite = overwrite
+        self.model_name = _get('model_name', model_name)
+        self.device = _get('device', device)
+
+        # Validate model
+        if self.model_name not in ['uni', 'gigapath', 'virchow2']:
+            raise ValueError(f'Invalid model: {self.model_name}')
+
+        # Dataset paths
+        self.feature_name = f'{self.model_name}/features'
+        self.latent_feature_name = f'{self.model_name}/latent_features'
+
+    def __call__(self, hdf5_path: str) -> PatchEmbeddingResult:
+        """
+        Execute embedding extraction
+
+        Args:
+            hdf5_path: Path to HDF5 file
+
+        Returns:
+            PatchEmbeddingResult: Result metadata (feature_dim, patch_count, skipped, etc.)
+        """
+        # Load model
+        model = create_model(self.model_name)
+        model = model.eval().to(self.device)
+
+        # Normalization parameters
+        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
+        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)
+
+        done = False
+
+        try:
+            with h5py.File(hdf5_path, 'r+') as f:
+                latent_size = model.patch_embed.proj.kernel_size[0]
+
+                # Check if already exists
+                if not self.overwrite:
+                    if self.with_latent:
+                        if (self.feature_name in f) and (self.latent_feature_name in f):
+                            if _config.verbose:
+                                print('Already extracted. Skipped.')
+                            return PatchEmbeddingResult(skipped=True)
+                        if (self.feature_name in f) or (self.latent_feature_name in f):
+                            raise RuntimeError(
+                                f'Either {self.feature_name} or {self.latent_feature_name} exists.'
+                            )
+                    else:
+                        if self.feature_name in f:
+                            if _config.verbose:
+                                print('Already extracted. Skipped.')
+                            return PatchEmbeddingResult(skipped=True)
+
+                # Delete if overwrite
+                if self.overwrite:
+                    safe_del(f, self.feature_name)
+                    safe_del(f, self.latent_feature_name)
+
+                # Get patch count
+                patch_count = f['metadata/patch_count'][()]
+
+                # Create batch indices
+                batch_idx = [
+                    (i, min(i + self.batch_size, patch_count))
+                    for i in range(0, patch_count, self.batch_size)
+                ]
+
+                # Create datasets
+                f.create_dataset(
+                    self.feature_name,
+                    shape=(patch_count, model.num_features),
+                    dtype=np.float32
+                )
+                if self.with_latent:
+                    f.create_dataset(
+                        self.latent_feature_name,
+                        shape=(patch_count, latent_size**2, model.num_features),
+                        dtype=np.float16
+                    )
+
+                # Process batches
+                tq = _progress(batch_idx)
+                for i0, i1 in tq:
+                    # Load batch
+                    x = f['patches'][i0:i1]
+                    x = (torch.from_numpy(x) / 255).permute(0, 3, 1, 2)  # BHWC->BCHW
+                    x = x.to(self.device)
+                    x = (x - mean) / std
+
+                    # Forward pass
+                    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
+                        h_tensor = model.forward_features(x)
+
+                    # Extract features
+                    h = h_tensor.cpu().detach().numpy()  # [B, T+L, H]
+                    latent_index = h.shape[1] - latent_size**2
+                    cls_feature = h[:, 0, ...]
+                    latent_feature = h[:, latent_index:, ...]
+
+                    # Save features
+                    f[self.feature_name][i0:i1] = cls_feature
+                    if self.with_latent:
+                        f[self.latent_feature_name][i0:i1] = latent_feature.astype(np.float16)
+
+                    # Cleanup
+                    del x, h_tensor
+                    torch.cuda.empty_cache()
+
+                    tq.set_description(f'Processing {i0}-{i1} (total={patch_count})')
+                    tq.refresh()
+
+                if _config.verbose:
+                    print(f'Embeddings dimension: {f[self.feature_name].shape}')
+
+                done = True
+
+                return PatchEmbeddingResult(
+                    feature_dim=model.num_features,
+                    patch_count=patch_count,
+                    model=self.model_name,
+                    with_latent=self.with_latent
+                )
+
+        finally:
+            if done and _config.verbose:
+                print(f'Wrote {self.feature_name}')
+            elif not done:
+                # Cleanup on error
+                with h5py.File(hdf5_path, 'a') as f:
+                    safe_del(f, self.feature_name)
+                    if self.with_latent:
+                        safe_del(f, self.latent_feature_name)
+                    if _config.verbose:
+                        print(f'ABORTED! Deleted {self.feature_name}')
+
+            # Cleanup
+            del model, mean, std
+            torch.cuda.empty_cache()
+            gc.collect()
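Taken together with the clustering command above, the embedding step slots into a pipeline like the following sketch. It assumes wsi_toolbox.commands exposes the set_default_model / set_default_device helpers shown in the docstrings, and that slide.h5 (a placeholder name) already holds the patches and metadata/patch_count datasets produced by the patching step:

    from wsi_toolbox import commands
    from wsi_toolbox.commands.patch_embedding import PatchEmbeddingCommand
    from wsi_toolbox.commands.clustering import ClusteringCommand

    # Global defaults, later picked up via _get('model_name', ...) / _get('device', ...)
    commands.set_default_model('gigapath')
    commands.set_default_device('cuda')

    # 1. Extract per-patch embeddings into gigapath/features
    emb = PatchEmbeddingCommand(batch_size=256)(hdf5_path='slide.h5')

    # 2. Leiden-cluster those features and write gigapath/clusters
    res = ClusteringCommand(resolution=1.0)(hdf5_paths='slide.h5')

    print(emb.feature_dim, res.cluster_count)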