wsi-toolbox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,214 @@
+ """
+ Clustering command for WSI features
+ """
+
+ import h5py
+ import numpy as np
+ from pydantic import BaseModel
+ from sklearn.preprocessing import StandardScaler
+
+ from ..utils.analysis import leiden_cluster
+ from . import _config, _get
+
+
+ class ClusteringResult(BaseModel):
+     """Result of clustering operation"""
+     cluster_count: int
+     feature_count: int
+     target_path: str
+     skipped: bool = False
+
+
+ class ClusteringCommand:
+     """
+     Perform clustering on extracted features
+
+     Usage:
+         # Set global config once
+         commands.set_default_model('gigapath')
+
+         # Create and run command
+         cmd = ClusteringCommand(resolution=1.0)
+         result = cmd(hdf5_paths=['data1.h5', 'data2.h5'])
+     """
+
+     def __init__(self,
+                  resolution: float = 1.0,
+                  cluster_name: str = '',
+                  cluster_filter: list[int] | None = None,
+                  use_umap: bool = False,
+                  overwrite: bool = False,
+                  model_name: str | None = None):
+         """
+         Initialize clustering command
+
+         Args:
+             resolution: Leiden clustering resolution
+             cluster_name: Name for multi-file clustering
+             cluster_filter: Filter to specific clusters for sub-clustering
+             use_umap: Whether to use UMAP embeddings for clustering
+             overwrite: Whether to overwrite existing clusters
+             model_name: Model name (None to use global default)
+         """
+         self.resolution = resolution
+         self.cluster_name = cluster_name
+         self.cluster_filter = cluster_filter or []
+         self.use_umap = use_umap
+         self.overwrite = overwrite
+         self.model_name = _get('model_name', model_name)
+
+         # Validate model
+         if self.model_name not in ['uni', 'gigapath', 'virchow2']:
+             raise ValueError(f'Invalid model: {self.model_name}')
+
+         self.sub_clustering = len(self.cluster_filter) > 0
+
+         # Internal state (initialized in __call__)
+         self.hdf5_paths = []
+         self.masks = []
+         self.features = None
+         self.total_clusters = None
+         self.umap_embeddings = None
+
+     def __call__(self, hdf5_paths: str | list[str]) -> ClusteringResult:
+         """
+         Execute clustering
+
+         Args:
+             hdf5_paths: Single HDF5 path or list of paths
+
+         Returns:
+             ClusteringResult: Result metadata (cluster_count, etc.)
+         """
+         # Normalize to list
+         if isinstance(hdf5_paths, str):
+             hdf5_paths = [hdf5_paths]
+
+         self.hdf5_paths = hdf5_paths
+         multi = len(hdf5_paths) > 1
+
+         # Validate multi-file clustering
+         if multi and not self.cluster_name:
+             raise RuntimeError('Multiple files provided but cluster_name was not specified.')
+
+         # Determine cluster path
+         if multi:
+             clusters_path = f'{self.model_name}/clusters_{self.cluster_name}'
+         else:
+             clusters_path = f'{self.model_name}/clusters'
+
+         # Load features
+         self._load_features(clusters_path)
+
+         # Check if already exists
+         if not self.sub_clustering and hasattr(self, 'has_clusters') and self.has_clusters and not self.overwrite:
+             if _config.verbose:
+                 print('Skip clustering (already exists)')
+             return ClusteringResult(
+                 cluster_count=len(np.unique(self.total_clusters)),
+                 feature_count=len(self.features),
+                 target_path=clusters_path,
+                 skipped=True
+             )
+
+         # Perform clustering
+         self.total_clusters = leiden_cluster(
+             self.features,
+             umap_emb_func=self.get_umap_embeddings if self.use_umap else None,
+             resolution=self.resolution,
+             progress=_config.progress
+         )
+
+         # Write results
+         target_path = clusters_path
+         if self.sub_clustering:
+             suffix = '_sub' + '-'.join(map(str, self.cluster_filter))
+             target_path = target_path + suffix
+
+         if _config.verbose:
+             print(f'Writing to {target_path}')
+
+         cursor = 0
+         for hdf5_path, mask in zip(self.hdf5_paths, self.masks):
+             count = np.sum(mask)
+             clusters = self.total_clusters[cursor:cursor + count]
+             cursor += count
+
+             with h5py.File(hdf5_path, 'a') as f:
+                 if target_path in f:
+                     del f[target_path]
+
+                 # Fill with -1 for filtered patches
+                 full_clusters = np.full(len(mask), -1, dtype=clusters.dtype)
+                 full_clusters[mask] = clusters
+                 f.create_dataset(target_path, data=full_clusters)
+
+         cluster_count = len(np.unique(self.total_clusters))
+
+         return ClusteringResult(
+             cluster_count=cluster_count,
+             feature_count=len(self.features),
+             target_path=target_path
+         )
+
+     def _load_features(self, clusters_path: str):
+         """Load features from HDF5 files"""
+         featuress = []
+         clusterss = []
+         self.masks = []
+
+         for hdf5_path in self.hdf5_paths:
+             with h5py.File(hdf5_path, 'r') as f:
+                 patch_count = f['metadata/patch_count'][()]
+
+                 # Check existing clusters
+                 if clusters_path in f:
+                     clusters = f[clusters_path][:]
+                 else:
+                     clusters = None
+
+                 # Create mask
+                 if self.cluster_filter:
+                     if clusters is None:
+                         raise RuntimeError('Sub-clustering requires pre-computed clusters')
+                     mask = np.isin(clusters, self.cluster_filter)
+                 else:
+                     mask = np.ones(patch_count, dtype=bool)
+
+                 self.masks.append(mask)
+
+                 # Load features
+                 feature_path = f'{self.model_name}/features'
+                 features = f[feature_path][mask]
+                 featuress.append(features)
+
+                 # Store existing clusters
+                 if clusters is not None:
+                     clusterss.append(clusters[mask])
+
+         # Concatenate and normalize
+         features = np.concatenate(featuress)
+         scaler = StandardScaler()
+         self.features = scaler.fit_transform(features)
+
+         # Store existing clusters state
+         if len(clusterss) == len(self.hdf5_paths):
+             self.has_clusters = True
+             self.total_clusters = np.concatenate(clusterss)
+         elif len(clusterss) == 0:
+             self.has_clusters = False
+             self.total_clusters = None
+         else:
+             raise RuntimeError(
+                 f'Cluster count mismatch: {len(clusterss)} vs {len(self.hdf5_paths)}'
+             )
+
+     def get_umap_embeddings(self):
+         """Get UMAP embeddings (lazy evaluation)"""
+         import umap
+
+         if self.umap_embeddings is not None:
+             return self.umap_embeddings
+
+         reducer = umap.UMAP(n_components=2)
+         self.umap_embeddings = reducer.fit_transform(self.features)
+         return self.umap_embeddings
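Note on usage: beyond the single- and multi-file patterns shown in the class docstring, the cluster_filter argument drives sub-clustering of an existing partition. A minimal sketch of that flow, assuming the commands are importable from a `commands` module of the wsi_toolbox package (as the docstring suggests) and using hypothetical file names and cluster IDs:

    # Hypothetical import path and file names, following the docstring's usage pattern
    from wsi_toolbox import commands
    from wsi_toolbox.commands import ClusteringCommand

    commands.set_default_model('gigapath')

    # Joint clustering across two slides; labels are written to
    # 'gigapath/clusters_cohortA' in each HDF5 file
    result = ClusteringCommand(resolution=1.0, cluster_name='cohortA')(
        hdf5_paths=['slide1.h5', 'slide2.h5'])
    print(result.cluster_count, result.target_path)

    # Re-cluster only patches previously assigned to clusters 2 and 5;
    # output goes to 'gigapath/clusters_cohortA_sub2-5', other patches get -1
    sub_result = ClusteringCommand(cluster_filter=[2, 5], cluster_name='cohortA')(
        hdf5_paths=['slide1.h5', 'slide2.h5'])
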
@@ -0,0 +1,202 @@
+ """
+ DZI export command for Deep Zoom Image format
+ """
+
+ import math
+ import shutil
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ from PIL import Image
+
+ from . import _config, _progress
+
+
+ class DziExportCommand:
+     """
+     Export HDF5 patches to DZI (Deep Zoom Image) format
+
+     Usage:
+         cmd = DziExportCommand(jpeg_quality=90, fill_empty=False)
+         cmd(hdf5_path='data.h5', output_dir='output', name='slide')
+     """
+
+     def __init__(self,
+                  jpeg_quality: int = 90,
+                  fill_empty: bool = False):
+         """
+         Initialize DZI export command
+
+         Args:
+             jpeg_quality: JPEG compression quality (0-100)
+             fill_empty: Fill missing tiles with black tiles
+         """
+         self.jpeg_quality = jpeg_quality
+         self.fill_empty = fill_empty
+
+     def __call__(self, hdf5_path: str, output_dir: str, name: str) -> dict:
+         """
+         Export to DZI format with full pyramid
+
+         Args:
+             hdf5_path: Path to HDF5 file
+             output_dir: Output directory
+             name: Base name for DZI files
+
+         Returns:
+             dict: Export metadata
+         """
+         output_dir = Path(output_dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # Read HDF5
+         with h5py.File(hdf5_path, 'r') as f:
+             patches = f['patches'][:]
+             coords = f['coordinates'][:]
+             original_width = f['metadata/original_width'][()]
+             original_height = f['metadata/original_height'][()]
+             tile_size = f['metadata/patch_size'][()]
+
+         # Validate tile_size (256 or 512 only)
+         if tile_size not in [256, 512]:
+             raise ValueError(f'Unsupported patch_size: {tile_size}. Only 256 or 512 are supported.')
+
+         # Calculate grid and levels
+         cols = (original_width + tile_size - 1) // tile_size
+         rows = (original_height + tile_size - 1) // tile_size
+         max_dimension = max(original_width, original_height)
+         max_level = math.ceil(math.log2(max_dimension))
+
+         if _config.verbose:
+             print(f'Original size: {original_width}x{original_height}')
+             print(f'Tile size: {tile_size}')
+             print(f'Grid: {cols}x{rows}')
+             print(f'Total patches in HDF5: {len(patches)}')
+             print(f'Max zoom level: {max_level} (Level 0 = 1x1, Level {max_level} = original)')
+
+         coord_to_idx = {(int(x // tile_size), int(y // tile_size)): idx
+                         for idx, (x, y) in enumerate(coords)}
+
+         # Setup directories
+         dzi_path = output_dir / f'{name}.dzi'
+         files_dir = output_dir / f'{name}_files'
+         files_dir.mkdir(exist_ok=True)
+
+         # Create empty tile template for current tile_size
+         empty_tile_path = None
+         if self.fill_empty:
+             empty_tile_path = files_dir / '_empty.jpeg'
+             black_img = Image.fromarray(np.zeros((tile_size, tile_size, 3), dtype=np.uint8))
+             black_img.save(empty_tile_path, 'JPEG', quality=self.jpeg_quality)
+
+         # Export max level (original patches from HDF5)
+         level_dir = files_dir / str(max_level)
+         level_dir.mkdir(exist_ok=True)
+
+         tq = _progress(range(rows))
+         for row in tq:
+             tq.set_description(f'Exporting level {max_level}: row {row+1}/{rows}')
+             for col in range(cols):
+                 tile_path = level_dir / f'{col}_{row}.jpeg'
+                 if (col, row) in coord_to_idx:
+                     idx = coord_to_idx[(col, row)]
+                     patch = patches[idx]
+                     img = Image.fromarray(patch)
+                     img.save(tile_path, 'JPEG', quality=self.jpeg_quality)
+                 elif self.fill_empty:
+                     shutil.copyfile(empty_tile_path, tile_path)
+
+         # Generate lower levels by downsampling
+         for level in range(max_level - 1, -1, -1):
+             if _config.verbose:
+                 print(f'Generating level {level}...')
+             self._generate_zoom_level_down(
+                 files_dir, level, max_level, original_width, original_height,
+                 tile_size, empty_tile_path
+             )
+
+         # Generate DZI XML
+         self._generate_dzi_xml(dzi_path, original_width, original_height, tile_size)
+
+         if _config.verbose:
+             print(f'DZI export complete: {dzi_path}')
+
+         return {
+             'dzi_path': str(dzi_path),
+             'max_level': max_level,
+             'tile_size': tile_size,
+             'grid': f'{cols}x{rows}'
+         }
+
+     def _generate_zoom_level_down(self,
+                                   files_dir: Path,
+                                   curr_level: int,
+                                   max_level: int,
+                                   original_width: int,
+                                   original_height: int,
+                                   tile_size: int,
+                                   empty_tile_path: Path | None):
+         """Generate a zoom level by downsampling from the higher level"""
+         src_level = curr_level + 1
+         src_dir = files_dir / str(src_level)
+         curr_dir = files_dir / str(curr_level)
+         curr_dir.mkdir(exist_ok=True)
+
+         # Calculate dimensions at each level
+         curr_scale = 2 ** (max_level - curr_level)
+         curr_width = math.ceil(original_width / curr_scale)
+         curr_height = math.ceil(original_height / curr_scale)
+         curr_cols = math.ceil(curr_width / tile_size)
+         curr_rows = math.ceil(curr_height / tile_size)
+
+         src_scale = 2 ** (max_level - src_level)
+         src_width = math.ceil(original_width / src_scale)
+         src_height = math.ceil(original_height / src_scale)
+         src_cols = math.ceil(src_width / tile_size)
+         src_rows = math.ceil(src_height / tile_size)
+
+         tq = _progress(range(curr_rows))
+         for row in tq:
+             for col in range(curr_cols):
+                 # Combine 4 tiles from source level
+                 combined = np.zeros((tile_size * 2, tile_size * 2, 3), dtype=np.uint8)
+                 has_any_tile = False
+
+                 for dy in range(2):
+                     for dx in range(2):
+                         src_col = col * 2 + dx
+                         src_row = row * 2 + dy
+
+                         if src_col < src_cols and src_row < src_rows:
+                             src_path = src_dir / f'{src_col}_{src_row}.jpeg'
+                             if src_path.exists():
+                                 src_img = Image.open(src_path)
+                                 src_array = np.array(src_img)
+                                 h, w = src_array.shape[:2]
+                                 combined[dy*tile_size:dy*tile_size+h,
+                                          dx*tile_size:dx*tile_size+w] = src_array
+                                 has_any_tile = True
+
+                 tile_path = curr_dir / f'{col}_{row}.jpeg'
+                 if has_any_tile:
+                     combined_img = Image.fromarray(combined)
+                     downsampled = combined_img.resize((tile_size, tile_size), Image.LANCZOS)
+                     downsampled.save(tile_path, 'JPEG', quality=self.jpeg_quality)
+                 elif self.fill_empty and empty_tile_path:
+                     shutil.copyfile(empty_tile_path, tile_path)
+
+             tq.set_description(f'Generating level {curr_level}: row {row+1}/{curr_rows}')
+
+     def _generate_dzi_xml(self, dzi_path: Path, width: int, height: int, tile_size: int):
+         """Generate DZI XML file"""
+         dzi_content = f'''<?xml version="1.0" encoding="utf-8"?>
+ <Image xmlns="http://schemas.microsoft.com/deepzoom/2008"
+        Format="jpeg"
+        Overlap="0"
+        TileSize="{tile_size}">
+     <Size Width="{width}" Height="{height}"/>
+ </Image>
+ '''
+         with open(dzi_path, 'w', encoding='utf-8') as f:
+             f.write(dzi_content)
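Note on the level arithmetic: the export follows the Deep Zoom convention, where level max_level is the original resolution and each lower level halves both dimensions, down to level 0. Tiles land in `{name}_files/{level}/{col}_{row}.jpeg`, which is the layout the generated .dzi XML describes. A short sketch of the same arithmetic used in _generate_zoom_level_down, run on a hypothetical 100,000 x 80,000 slide with 512-pixel tiles (numbers are illustrative only):

    import math

    original_width, original_height, tile_size = 100_000, 80_000, 512  # hypothetical slide
    max_level = math.ceil(math.log2(max(original_width, original_height)))  # 17, since 2**17 = 131072

    for level in (max_level, max_level - 1, max_level - 2):
        scale = 2 ** (max_level - level)            # 1, 2, 4
        width = math.ceil(original_width / scale)    # 100000, 50000, 25000
        height = math.ceil(original_height / scale)  # 80000, 40000, 20000
        cols = math.ceil(width / tile_size)
        rows = math.ceil(height / tile_size)
        print(f'level {level}: {width}x{height} px, {cols}x{rows} tiles')
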
@@ -0,0 +1,199 @@
+ """
+ Patch embedding extraction command
+ """
+
+ import gc
+
+ import h5py
+ import numpy as np
+ import torch
+ from pydantic import BaseModel
+
+ from ..models import create_model
+ from ..utils.helpers import safe_del
+ from . import _config, _get, _progress
+
+
+ class PatchEmbeddingResult(BaseModel):
+     """Result of patch embedding extraction"""
+     feature_dim: int = 0
+     patch_count: int = 0
+     model: str = ''
+     with_latent: bool = False
+     skipped: bool = False
+
+
+ class PatchEmbeddingCommand:
+     """
+     Extract embeddings from patches using foundation models
+
+     Usage:
+         # Set global config once
+         commands.set_default_model('gigapath')
+         commands.set_default_device('cuda')
+
+         # Create and run command
+         cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
+         result = cmd(hdf5_path='data.h5')
+     """
+
+     def __init__(self,
+                  batch_size: int = 256,
+                  with_latent: bool = False,
+                  overwrite: bool = False,
+                  model_name: str | None = None,
+                  device: str | None = None):
+         """
+         Initialize patch embedding extractor
+
+         Args:
+             batch_size: Batch size for inference
+             with_latent: Whether to extract latent features
+             overwrite: Whether to overwrite existing features
+             model_name: Model name (None to use global default)
+             device: Device (None to use global default)
+
+         Note:
+             progress and verbose are controlled by global config
+         """
+         self.batch_size = batch_size
+         self.with_latent = with_latent
+         self.overwrite = overwrite
+         self.model_name = _get('model_name', model_name)
+         self.device = _get('device', device)
+
+         # Validate model
+         if self.model_name not in ['uni', 'gigapath', 'virchow2']:
+             raise ValueError(f'Invalid model: {self.model_name}')
+
+         # Dataset paths
+         self.feature_name = f'{self.model_name}/features'
+         self.latent_feature_name = f'{self.model_name}/latent_features'
+
+     def __call__(self, hdf5_path: str) -> PatchEmbeddingResult:
+         """
+         Execute embedding extraction
+
+         Args:
+             hdf5_path: Path to HDF5 file
+
+         Returns:
+             PatchEmbeddingResult: Result metadata (feature_dim, patch_count, skipped, etc.)
+         """
+         # Load model
+         model = create_model(self.model_name)
+         model = model.eval().to(self.device)
+
+         # Normalization parameters
+         mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
+         std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)
+
+         done = False
+
+         try:
+             with h5py.File(hdf5_path, 'r+') as f:
+                 latent_size = model.patch_embed.proj.kernel_size[0]
+
+                 # Check if already exists
+                 if not self.overwrite:
+                     if self.with_latent:
+                         if (self.feature_name in f) and (self.latent_feature_name in f):
+                             if _config.verbose:
+                                 print('Already extracted. Skipped.')
+                             # Mark as done so the cleanup in `finally` does not delete the existing datasets
+                             done = True
+                             return PatchEmbeddingResult(skipped=True)
+                         if (self.feature_name in f) or (self.latent_feature_name in f):
+                             raise RuntimeError(
+                                 f'Either {self.feature_name} or {self.latent_feature_name} exists.'
+                             )
+                     else:
+                         if self.feature_name in f:
+                             if _config.verbose:
+                                 print('Already extracted. Skipped.')
+                             # Mark as done so the cleanup in `finally` does not delete the existing datasets
+                             done = True
+                             return PatchEmbeddingResult(skipped=True)
+
+                 # Delete if overwrite
+                 if self.overwrite:
+                     safe_del(f, self.feature_name)
+                     safe_del(f, self.latent_feature_name)
+
+                 # Get patch count
+                 patch_count = f['metadata/patch_count'][()]
+
+                 # Create batch indices
+                 batch_idx = [
+                     (i, min(i + self.batch_size, patch_count))
+                     for i in range(0, patch_count, self.batch_size)
+                 ]
+
+                 # Create datasets
+                 f.create_dataset(
+                     self.feature_name,
+                     shape=(patch_count, model.num_features),
+                     dtype=np.float32
+                 )
+                 if self.with_latent:
+                     f.create_dataset(
+                         self.latent_feature_name,
+                         shape=(patch_count, latent_size**2, model.num_features),
+                         dtype=np.float16
+                     )
+
+                 # Process batches
+                 tq = _progress(batch_idx)
+                 for i0, i1 in tq:
+                     # Load batch
+                     x = f['patches'][i0:i1]
+                     x = (torch.from_numpy(x) / 255).permute(0, 3, 1, 2)  # BHWC->BCHW
+                     x = x.to(self.device)
+                     x = (x - mean) / std
+
+                     # Forward pass
+                     with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
+                         h_tensor = model.forward_features(x)
+
+                     # Extract features
+                     h = h_tensor.cpu().detach().numpy()  # [B, T+L, H]
+                     latent_index = h.shape[1] - latent_size**2
+                     cls_feature = h[:, 0, ...]
+                     latent_feature = h[:, latent_index:, ...]
+
+                     # Save features
+                     f[self.feature_name][i0:i1] = cls_feature
+                     if self.with_latent:
+                         f[self.latent_feature_name][i0:i1] = latent_feature.astype(np.float16)
+
+                     # Cleanup
+                     del x, h_tensor
+                     torch.cuda.empty_cache()
+
+                     tq.set_description(f'Processing {i0}-{i1} (total={patch_count})')
+                     tq.refresh()
+
+                 if _config.verbose:
+                     print(f'Embeddings dimension: {f[self.feature_name].shape}')
+
+                 done = True
+
+                 return PatchEmbeddingResult(
+                     feature_dim=model.num_features,
+                     patch_count=patch_count,
+                     model=self.model_name,
+                     with_latent=self.with_latent
+                 )
+
+         finally:
+             if done and _config.verbose:
+                 print(f'Wrote {self.feature_name}')
+             elif not done:
+                 # Cleanup on error
+                 with h5py.File(hdf5_path, 'a') as f:
+                     safe_del(f, self.feature_name)
+                     if self.with_latent:
+                         safe_del(f, self.latent_feature_name)
+                 if _config.verbose:
+                     print(f'ABORTED! Deleted {self.feature_name}')
+
+             # Cleanup
+             del model, mean, std
+             torch.cuda.empty_cache()
+             gc.collect()
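Taken together, the three commands form a patch-to-viewer pipeline: embedding extraction writes '<model>/features' into the HDF5 file, clustering writes '<model>/clusters', and the DZI export turns the stored patches into a Deep Zoom pyramid. A minimal sketch of chaining them on one slide, assuming the classes are importable from a `commands` package (the exact module paths are not shown in this diff) and using a hypothetical file name:

    # Import paths and file names are assumptions; class names and defaults come from the code above
    from wsi_toolbox import commands
    from wsi_toolbox.commands import PatchEmbeddingCommand, ClusteringCommand, DziExportCommand

    commands.set_default_model('gigapath')
    commands.set_default_device('cuda')

    path = 'slide.h5'  # hypothetical HDF5 produced by the patching step

    # 1. Extract per-patch embeddings into 'gigapath/features'
    emb = PatchEmbeddingCommand(batch_size=256)(hdf5_path=path)

    # 2. Leiden-cluster the features; labels go to 'gigapath/clusters'
    clu = ClusteringCommand(resolution=1.0)(hdf5_paths=path)

    # 3. Export the raw patches as a Deep Zoom pyramid for viewing
    dzi = DziExportCommand(jpeg_quality=90)(hdf5_path=path, output_dir='dzi_out', name='slide')

    print(emb.feature_dim, clu.cluster_count, dzi['dzi_path'])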